Compare commits

...

150 Commits

Author SHA1 Message Date
Josh Yan
b6c7d01af3 more cmt rmv 2024-07-11 17:21:36 -07:00
Josh Yan
9d517cf556 rm comment 2024-07-11 17:20:09 -07:00
Josh Yan
6bab0e2368 lint 2024-07-10 12:36:32 -07:00
Josh Yan
c4cccaf936 remove rebase err 2024-07-10 11:37:55 -07:00
Josh Yan
9fe5c393e4 hi 2024-07-10 11:35:01 -07:00
Josh Yan
007c988dba rmv double msg 2024-07-10 11:34:31 -07:00
Josh Yan
91d21e7c7b rmv double msg 2024-07-10 11:34:31 -07:00
Josh Yan
3e64284f69 percent 2024-07-10 11:34:31 -07:00
Josh Yan
39910f2ab2 percent 2024-07-10 11:34:29 -07:00
Josh Yan
96d0cd92f2 rebase 2024-07-10 11:31:53 -07:00
Josh Yan
3a724a7c80 isLocal firstdraft 2024-07-10 11:31:12 -07:00
Josh Yan
f520f0056e rm config 2024-07-10 11:29:51 -07:00
Josh Yan
d25f85ede4 on disk copy 2024-07-10 11:29:49 -07:00
Josh Yan
b48420b74b percent 2024-07-10 11:27:33 -07:00
Josh Yan
784958a1cb transfer data 2024-07-10 11:24:50 -07:00
Josh Yan
ae65cc8dea progress 2024-07-10 11:24:48 -07:00
Josh Yan
a037528bba lint 2024-07-10 11:20:02 -07:00
Josh Yan
04bf41deb5 clean 2024-07-10 11:20:02 -07:00
Josh Yan
c23cec9547 removed cmt and prints 2024-07-10 11:20:02 -07:00
Josh Yan
8377dc48d0 removed client isLocal() 2024-07-10 11:20:02 -07:00
Josh Yan
3aee405dfa lint 2024-07-10 11:20:02 -07:00
Josh Yan
9b3f47b674 lint 2024-07-10 11:20:02 -07:00
Josh Yan
f5441f01a2 lint 2024-07-10 11:20:02 -07:00
Josh Yan
ab165df43a syscopy windows 2024-07-10 11:20:02 -07:00
Josh Yan
79cc4c9585 os copy 2024-07-10 11:20:02 -07:00
Josh Yan
bc3f59a6ad rmv prints 2024-07-10 11:20:02 -07:00
Josh Yan
1a85cb904c local copy 2024-07-10 11:20:02 -07:00
Josh Yan
10ea0987e9 isLocal firstdraft 2024-07-10 11:19:50 -07:00
Josh Yan
413d368a6a clean 2024-07-10 11:19:32 -07:00
Josh Yan
cabf375059 rm bench 2024-07-10 11:19:32 -07:00
Josh Yan
ca0ee1d4fe rm config 2024-07-10 11:19:32 -07:00
Josh Yan
1142999aab rm config 2024-07-10 11:19:32 -07:00
Josh Yan
0d5a72aba9 clean 2024-07-10 11:19:32 -07:00
Josh Yan
ea837412c2 local path 2024-07-10 11:19:32 -07:00
Josh Yan
736ad6f438 still works 2024-07-10 11:19:32 -07:00
Josh Yan
64607d16a5 working 2024-07-10 11:19:32 -07:00
Josh Yan
a6cfe7f00b benchmark 2024-07-10 11:19:32 -07:00
Josh Yan
c3b411a515 on disk copy 2024-07-10 11:19:32 -07:00
Josh Yan
928f37e3ae start tests 2024-07-10 11:19:32 -07:00
Daniel Hiltgen
2d1e3c3229 Merge pull request #5503 from dhiltgen/dual_rocm
Workaround broken ROCm p2p copy
2024-07-09 15:44:16 -07:00
royjhan
4918fae535 OpenAI v1/completions: allow stop token list (#5551)
* stop token parsing fix

* add stop test
2024-07-09 14:01:26 -07:00
royjhan
0aff67877e separate request tests (#5578) 2024-07-09 13:48:31 -07:00
Daniel Hiltgen
9544a57ee4 Merge pull request #5579 from dhiltgen/win_static_deps
Statically link c++ and thread lib on windows
2024-07-09 12:21:13 -07:00
Daniel Hiltgen
b51e3b63ac Statically link c++ and thread lib
This makes sure we statically link the c++ and thread library on windows
to avoid unnecessary runtime dependencies on non-standard DLLs
2024-07-09 11:34:30 -07:00
Michael Yang
6bbbc50f10 Merge pull request #5440 from ollama/mxyng/messages-templates
update named templates
2024-07-09 09:36:32 -07:00
Michael Yang
9bbddc37a7 Merge pull request #5126 from ollama/mxyng/messages
update message processing
2024-07-09 09:20:44 -07:00
Jeffrey Morgan
e4ff73297d server: fix model reloads when setting OLLAMA_NUM_PARALLEL (#5560)
* server: fix unneeded model reloads when setting `OLLAMA_NUM_PARALLEL`

* remove whitespace change

* undo some changes
2024-07-08 22:32:15 -07:00
Daniel Hiltgen
0bacb30007 Workaround broken ROCm p2p copy
Enable the build flag for llama.cpp to use CPU copy for multi-GPU scenarios.
2024-07-08 09:40:52 -07:00
Jeffrey Morgan
53da2c6965 llm: remove ambiguous comment when putting upper limit on predictions to avoid infinite generation (#5535) 2024-07-07 14:32:05 -04:00
Jeffrey Morgan
d8def1ff94 llm: allow gemma 2 to context shift (#5534) 2024-07-07 13:41:51 -04:00
Jeffrey Morgan
571dc61955 Update llama.cpp submodule to a8db2a9c (#5530) 2024-07-07 13:03:09 -04:00
Jeffrey Morgan
0e09c380fc llm: print caching notices in debug only (#5533) 2024-07-07 12:38:04 -04:00
Jeffrey Morgan
0ee87615c7 sched: don't error if paging to disk on Windows and macOS (#5523) 2024-07-06 22:01:52 -04:00
Jeffrey Morgan
f8241bfba3 gpu: report system free memory instead of 0 (#5521) 2024-07-06 19:35:04 -04:00
Jeffrey Morgan
4607c70641 llm: add -DBUILD_SHARED_LIBS=off to common cpu cmake flags (#5520) 2024-07-06 18:58:16 -04:00
jmorganca
c12f1c5b99 release: move mingw library cleanup to correct job 2024-07-06 16:12:29 -04:00
jmorganca
a08f20d910 release: remove unwanted mingw dll.a files 2024-07-06 15:21:15 -04:00
jmorganca
6cea036027 Revert "llm: only statically link libstdc++"
This reverts commit 5796bfc401.
2024-07-06 15:10:48 -04:00
jmorganca
5796bfc401 llm: only statically link libstdc++ 2024-07-06 14:06:20 -04:00
jmorganca
f1a379aa56 llm: statically link pthread and stdc++ dependencies in windows build 2024-07-06 12:54:02 -04:00
jmorganca
9ae146993e llm: add GGML_STATIC flag to windows static lib 2024-07-06 03:27:05 -04:00
Jeffrey Morgan
e0348d3fe8 llm: add COMMON_DARWIN_DEFS to arm static build (#5513) 2024-07-05 22:42:42 -04:00
Jeffrey Morgan
2cc854f8cb llm: fix missing dylibs by restoring old build behavior on Linux and macOS (#5511)
* Revert "fix cmake build (#5505)"

This reverts commit 4fd5f3526a.

* llm: fix missing dylibs by restoring old build behavior

* crlf -> lf
2024-07-05 21:48:31 -04:00
Jeffrey Morgan
5304b765b2 llm: put back old include dir (#5507)
* llm: put back old include dir

* llm: update link paths for old submodule commits
2024-07-05 19:34:21 -04:00
Michael Yang
fb6cbc02fb update named templates 2024-07-05 16:29:32 -07:00
Jeffrey Morgan
4fd5f3526a fix cmake build (#5505) 2024-07-05 19:07:01 -04:00
Daniel Hiltgen
842f85f758 Merge pull request #5502 from dhiltgen/ci_fixes
Always go build in CI generate steps
2024-07-05 15:39:11 -07:00
Daniel Hiltgen
9d30f9f8b3 Always go build in CI generate steps
With the recent cgo changes, bugs can sneak through
if we don't make sure to `go build` all the permutations
2024-07-05 15:31:52 -07:00
Blake Mizerany
631cfd9e62 types/model: remove knowledge of digest (#5500)
This was leading to ambiguity and confusion in ollama.com, and is not
used anywhere in ollama at the moment. Once manifests are addressable by
digest, we can add this back in, and in a way that is more tailored to
the concept of addressing a manifest by digest.
2024-07-05 13:42:30 -07:00
Michael Yang
326363b3a7 no funcs 2024-07-05 13:17:25 -07:00
Michael Yang
ac7a842e55 fix model reloading
ensure runtime model changes (template, system prompt, messages,
options) are captured on model updates without needing to reload the
server
2024-07-05 13:17:25 -07:00
Michael Yang
2c3fe1fd97 comments 2024-07-05 13:17:24 -07:00
Michael Yang
269ed6e6a2 update message processing 2024-07-05 13:16:58 -07:00
Jeffrey Morgan
78fb33dd07 fix typo in cgo directives in llm.go (#5501) 2024-07-05 15:18:36 -04:00
Jeffrey Morgan
8f8e736b13 update llama.cpp submodule to d7fd29f (#5475) 2024-07-05 13:25:58 -04:00
Jeffrey Morgan
d89454de80 Use slot with cached prompt instead of least recently used (#5492)
* Use common prefix to select slot

* actually report `longest`
2024-07-05 12:32:47 -04:00
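
A sketch of the slot-selection idea named in this commit (assumed structures, not the actual llama.cpp server code): prefer the free slot whose cached prompt shares the longest common prefix with the incoming prompt, rather than the least recently used slot.

```go
// Hypothetical illustration of common-prefix slot selection.
package server

type slot struct {
	id    int
	cache []int // token IDs already processed in this slot
	inUse bool
}

func commonPrefix(a, b []int) int {
	n := 0
	for n < len(a) && n < len(b) && a[n] == b[n] {
		n++
	}
	return n
}

// pickSlot returns the free slot with the longest cached prefix, so the
// matching tokens can be reused instead of recomputed.
func pickSlot(slots []slot, prompt []int) *slot {
	var best *slot
	longest := -1
	for i := range slots {
		if slots[i].inUse {
			continue
		}
		if p := commonPrefix(slots[i].cache, prompt); p > longest {
			longest = p
			best = &slots[i]
		}
	}
	return best
}
```
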
Daniel Hiltgen
af28b94533 Merge pull request #5469 from dhiltgen/prevent_system_oom
Prevent loading models larger than total memory
2024-07-05 08:22:20 -07:00
Jeffrey Morgan
e9188e971a Fix assert on small embedding inputs (#5491)
* Fix assert on small embedding inputs

* Update llm/patches/09-pooling.diff
2024-07-05 11:20:57 -04:00
Daniel Hiltgen
78eddfc068 Merge pull request #4412 from dhiltgen/win_docs
Document older win10 terminal problems
2024-07-05 08:18:22 -07:00
Daniel Hiltgen
02c24d3d01 Merge pull request #5466 from dhiltgen/fix_clip_unicode
Fix clip model loading with unicode paths
2024-07-05 08:16:58 -07:00
Daniel Hiltgen
52abc8acb7 Document older win10 terminal problems
We haven't found a workaround, so for now we recommend updating.
2024-07-03 17:32:14 -07:00
Jeffrey Morgan
4d71c559b2 fix error detection by limiting model loading error parsing (#5472) 2024-07-03 20:04:30 -04:00
Anatoli Babenia
0d16eb310e fix: use envconfig.ModelsDir directly (#4821)
* Co-authored-by: Anatoli Babenia <anatoli@rainforce.org>

Co-authored-by: Maas Lalani <maas@lalani.dev>
2024-07-03 15:36:11 -07:00
Daniel Hiltgen
8072e205ff Merge pull request #5447 from dhiltgen/fix_keepalive
Only set default keep_alive on initial model load
2024-07-03 15:34:38 -07:00
Daniel Hiltgen
955f2a4e03 Only set default keep_alive on initial model load
This change fixes the handling of keep_alive so that if a client
request omits the setting, we only set it on initial load.  Once
the model is loaded, if new requests leave this unset, we'll keep
whatever keep_alive was there.
2024-07-03 15:29:56 -07:00
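
The defaulting rule this commit describes can be sketched as follows (a minimal illustration with assumed names, not the actual ollama scheduler code):

```go
// Hypothetical sketch: the default keep_alive applies only on the initial
// load; later requests that omit it inherit the current value.
package keepalive

import "time"

const defaultKeepAlive = 5 * time.Minute // ollama's documented default

type loadedModel struct {
	keepAlive time.Duration
}

// resolve returns the keep_alive to use for a request. A nil req means the
// client omitted the setting.
func resolve(req *time.Duration, existing *loadedModel) time.Duration {
	if req != nil {
		return *req // an explicit value always wins
	}
	if existing != nil {
		return existing.keepAlive // already loaded: keep whatever was there
	}
	return defaultKeepAlive // initial load with no value: apply the default
}
```
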
Daniel Hiltgen
3c75113e37 Prevent loading models larger than total memory
Users may not realize the shiny new model they're trying to load
fits on their disk, but can't load into system+GPU memory.  Today
we crash, but with this fix, we'll give them a better error message
before even trying to load it.
2024-07-03 14:47:42 -07:00
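
An illustrative sketch of the up-front check this commit describes (assumed names and fields, not the actual ollama code):

```go
// Hypothetical pre-load memory check.
package scheduler

import "fmt"

type systemInfo struct {
	totalMemory uint64 // system RAM, bytes
	totalVRAM   uint64 // VRAM summed across GPUs, bytes
}

// checkFits fails fast with a clear message instead of crashing mid-load.
func checkFits(modelSize uint64, sys systemInfo) error {
	if capacity := sys.totalMemory + sys.totalVRAM; modelSize > capacity {
		return fmt.Errorf("model requires %d bytes but only %d bytes of system+GPU memory are available", modelSize, capacity)
	}
	return nil
}
```
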
Daniel Hiltgen
ccd7785859 Merge pull request #5243 from dhiltgen/modelfile_use_mmap
Fix use_mmap for modelfiles
2024-07-03 13:59:42 -07:00
royjhan
3b5a4a77f3 Return Correct Prompt Eval Count Regardless of Cache Prompt (#5371)
* openai compatibility

* Revert "openai compatibility"

This reverts commit d3f98a811e.

* remove erroneous subtraction of prompt cache
2024-07-03 13:46:23 -07:00
Daniel Hiltgen
daed0634a9 Merge pull request #5467 from dhiltgen/bogus_cpu_mac_error
Fix corner cases on tmp cleaner on mac
2024-07-03 13:39:36 -07:00
Daniel Hiltgen
0d4dd707bc Merge pull request #5465 from dhiltgen/better_cuda_logging
Better nvidia GPU discovery logging
2024-07-03 13:12:22 -07:00
Daniel Hiltgen
0e982bc1f4 Fix corner cases on tmp cleaner on mac
When ollama has been running for a long time, tmp cleaners can remove the
runners.  This tightens up a few corner cases on ARM Macs where
we failed with "server cpu not listed in available servers map[]"
2024-07-03 13:10:14 -07:00
Daniel Hiltgen
6298f49816 Fix clip model loading with unicode paths
On windows, if the model dir contained unicode characters
clip models would fail to load.  This fixes the file name
handling in clip.cpp to support utf16 on windows.
2024-07-03 12:46:36 -07:00
Daniel Hiltgen
ef757da2c9 Better nvidia GPU discovery logging
Refine the way we log GPU discovery to improve the non-debug
output, and report more actionable log messages when possible
to help users troubleshoot on their own.
2024-07-03 10:50:40 -07:00
Michael Yang
e5352297d9 Merge pull request #5448 from ollama/mxyng/fix-generate
use model template by default
2024-07-02 16:48:06 -07:00
Michael Yang
65a5040e09 fix generate template 2024-07-02 16:42:17 -07:00
royjhan
d626b99b54 OpenAI: v1/completions compatibility (#5209)
* OpenAI v1 models

* Refactor Writers

* Add Test

Co-Authored-By: Attila Kerekes

* Credit Co-Author

Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com>

* Empty List Testing

* Use Namespace for Ownedby

* Update Test

* Add back envconfig

* v1/models docs

* Use ModelName Parser

* Test Names

* Remove Docs

* Clean Up

* Test name

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Add Middleware for Chat and List

* Completions Endpoint

* Testing Cleanup

* Test with Fatal

* Add functionality to chat test

* Rename function

* float types

* type cleanup

* cleaning

* more cleaning

* Extra test cases

* merge conflicts

* merge conflicts

* merge conflicts

* merge conflicts

* cleaning

* cleaning

---------

Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-07-02 16:01:45 -07:00
Michael Yang
dddb58a38b Merge pull request #5051 from ollama/mxyng/capabilities
add model capabilities
2024-07-02 14:26:07 -07:00
Michael Yang
400056e154 Merge pull request #5420 from ollama/mxyng/insecure-path
err on insecure path
2024-07-02 14:03:23 -07:00
Daniel Hiltgen
d2f19024d0 Merge pull request #5442 from dhiltgen/concurrency_docs
Add windows radeon concurrency note
2024-07-02 12:47:47 -07:00
Daniel Hiltgen
69c04eecc4 Add windows radeon concurrency note 2024-07-02 12:46:14 -07:00
royjhan
996bb1b85e OpenAI: /v1/models and /v1/models/{model} compatibility (#5007)
* OpenAI v1 models

* Refactor Writers

* Add Test

Co-Authored-By: Attila Kerekes

* Credit Co-Author

Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com>

* Empty List Testing

* Use Namespace for Ownedby

* Update Test

* Add back envconfig

* v1/models docs

* Use ModelName Parser

* Test Names

* Remove Docs

* Clean Up

* Test name

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Add Middleware for Chat and List

* Testing Cleanup

* Test with Fatal

* Add functionality to chat test

* OpenAI: /v1/models/{model} compatibility (#5028)

* Retrieve Model

* OpenAI Delete Model

* Retrieve Middleware

* Remove Delete from Branch

* Update Test

* Middleware Test File

* Function name

* Cleanup

* Test Update

* Test Update

---------

Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-07-02 11:50:56 -07:00
Daniel Hiltgen
422dcc3856 Merge pull request #5439 from dhiltgen/fix_centos_7_build
Switch ARM64 container image base to rocky 8
2024-07-02 11:01:15 -07:00
Daniel Hiltgen
020bd60ab2 Switch arm container image base to rocky 8
The CentOS 7 ARM mirrors have disappeared due to the EOL two days
ago, and the vault sed workaround that works for x86 doesn't work for ARM.
2024-07-02 10:34:47 -07:00
Daniel Hiltgen
8e277b72bb Merge pull request #5438 from dhiltgen/fix_centos_7_build
Centos 7 EOL broke mirrors
2024-07-02 09:28:00 -07:00
Daniel Hiltgen
4f67b39d26 Centos 7 EOL broke mirrors
As of July 1st 2024: Could not resolve host: mirrorlist.centos.org
This is expected due to EOL dates.
2024-07-02 09:22:17 -07:00
Josh
2425281317 Merge pull request #5336 from ollama/jyan/from-errors
fix: trim spaces for FROM argument, don't trim inside of quotes
2024-07-01 16:32:46 -07:00
Josh
0403e9860e Merge pull request #5421 from ollama/jyan/ver
fix: add unsupported architecture message for linux/windows
2024-07-01 16:32:14 -07:00
Josh Yan
33a65e3ba3 error 2024-07-01 16:04:13 -07:00
Michael Yang
88bcd79bb9 err on insecure path 2024-07-01 15:55:59 -07:00
Josh Yan
7e571f95f0 trimspace test case 2024-07-01 11:07:48 -07:00
Michael Yang
da8e2a0447 use kvs to detect embedding models 2024-07-01 10:47:43 -07:00
Michael Yang
a30915bde1 add capabilities 2024-07-01 10:47:43 -07:00
Michael Yang
58e3fff311 rename templates to template 2024-07-01 10:40:54 -07:00
Michael Yang
3f0b309ad4 remove ManifestV2 2024-07-01 10:40:54 -07:00
Daniel Hiltgen
e70610ef06 Merge pull request #5410 from dhiltgen/ctx_cleanup
Fix case for NumCtx
2024-07-01 09:54:20 -07:00
Daniel Hiltgen
dfded7e075 Merge pull request #5364 from dhiltgen/concurrency_docs
Document concurrent behavior and settings
2024-07-01 09:49:48 -07:00
Daniel Hiltgen
173b550438 Remove default auto from help message
This may confuse users thinking "auto" is an acceptable string - it must be numeric
2024-07-01 09:48:05 -07:00
Daniel Hiltgen
cff3f44f4a Fix case for NumCtx 2024-07-01 09:43:59 -07:00
Josh Yan
26e4e66faf updated parsefile test 2024-07-01 09:43:49 -07:00
Daniel Hiltgen
97c9e11768 Switch use_mmap to a pointer type
This uses nil as undefined for a cleaner implementation.
2024-07-01 08:44:59 -07:00
Daniel Hiltgen
3518aaef33 Merge pull request #4218 from dhiltgen/auto_parallel
Enable concurrency by default
2024-07-01 08:32:29 -07:00
RAPID ARCHITECT
1963c00201 Update README.md (#5214)
* Update README.md

Added Mesop example to web & desktop

* Update README.md

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-06-30 22:00:57 -04:00
Eduard
27402cb7a2 Update gpu.md (#5382)
Runs fine on an NVIDIA GeForce GTX 1050 Ti
2024-06-30 21:48:51 -04:00
Jeffrey Morgan
c1218199cf Update api.md 2024-06-29 16:22:49 -07:00
Jeffrey Morgan
717f7229eb Do not shift context for sliding window models (#5368)
* Do not shift context for sliding window models

* truncate prompt > 2/3 tokens

* only target gemma2
2024-06-28 19:39:31 -07:00
Daniel Hiltgen
aae56abb7c Document concurrent behavior and settings 2024-06-28 13:15:57 -07:00
royjhan
5f034f5b63 Include Show Info in Interactive (#5342) 2024-06-28 13:15:52 -07:00
royjhan
b910fa9010 Ollama Show: Check for Projector Type (#5307)
* Check exists projtype

* Maintain Ordering
2024-06-28 11:30:16 -07:00
royjhan
6d4219083c Update docs (#5312) 2024-06-28 09:58:14 -07:00
Michael Yang
1ed4f521c4 Merge pull request #5340 from ollama/mxyng/mem
gemma2 graph
2024-06-27 14:26:49 -07:00
Michael Yang
de2163dafd gemma2 graph 2024-06-27 13:34:52 -07:00
Josh Yan
9bd00041fa trim all params 2024-06-27 11:18:38 -07:00
Josh Yan
4e986a823c unquote, trimp space 2024-06-27 10:59:15 -07:00
Michael
2cc7d05012 update readme for gemma 2 (#5333)
* update readme for gemma 2
2024-06-27 12:45:16 -04:00
Michael Yang
123a722a6f zip: prevent extracting files into parent dirs (#5314) 2024-06-26 21:38:21 -07:00
Jeffrey Morgan
4d311eb731 llm: architecture patch (#5316) 2024-06-26 21:38:12 -07:00
Blake Mizerany
cb42e607c5 llm: speed up gguf decoding by a lot (#5246)
Previously, some costly things were causing the loading of GGUF files
and their metadata and tensor information to be VERY slow:

  * Too many allocations when decoding strings
  * Hitting disk for each read of each key and value, resulting in a
    not-okay amount of syscalls/disk I/O.

The show API is now down to 33ms from 800ms+ for llama3 on a macbook pro
m3.

This commit also prevents collecting large arrays of values when
decoding GGUFs (if desired). When such keys are encountered, their
values are null, and are encoded as such in JSON.

Also, this fixes a broken test that was not encoding valid GGUF.
2024-06-24 21:47:52 -07:00
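
The commit message above attributes the slowdown to per-key disk reads and string allocations. This toy sketch shows the general shape of the fix it describes: wrap the file in a single buffered reader so key/value decoding hits memory, not the disk. It is not the ollama GGUF decoder itself.

```go
// Hypothetical buffered GGUF metadata reading.
package gguf

import (
	"bufio"
	"encoding/binary"
	"io"
	"os"
)

// readString decodes a length-prefixed string, performing one allocation
// for the payload instead of many small ones.
func readString(r io.Reader) (string, error) {
	var n uint64
	if err := binary.Read(r, binary.LittleEndian, &n); err != nil {
		return "", err
	}
	buf := make([]byte, n)
	if _, err := io.ReadFull(r, buf); err != nil {
		return "", err
	}
	return string(buf), nil
}

func open(path string) (*bufio.Reader, *os.File, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, nil, err
	}
	// One 1 MiB buffer turns thousands of tiny read syscalls into a few.
	return bufio.NewReaderSize(f, 1<<20), f, nil
}
```
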
Blake Mizerany
2aa91a937b cmd: defer stating model info until necessary (#5248)
This commit changes the 'ollama run' command to defer fetching model
information until it really needs it. That is, when in interactive mode.

It also removes one such case where the model information is fetched in
duplicate, just before calling generateInteractive and then again, first
thing, in generateInteractive.

This positively impacts the performance of the command:

    ; time ./before run llama3 'hi'
    Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat?

    ./before run llama3 'hi'  0.02s user 0.01s system 2% cpu 1.168 total
    ; time ./before run llama3 'hi'
    Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat?

    ./before run llama3 'hi'  0.02s user 0.01s system 2% cpu 1.220 total
    ; time ./before run llama3 'hi'
    Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat?

    ./before run llama3 'hi'  0.02s user 0.01s system 2% cpu 1.217 total
    ; time ./after run llama3 'hi'
    Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat?

    ./after run llama3 'hi'  0.02s user 0.01s system 4% cpu 0.652 total
    ; time ./after run llama3 'hi'
    Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat?

    ./after run llama3 'hi'  0.01s user 0.01s system 5% cpu 0.498 total
    ; time ./after run llama3 'hi'
    Hi! It's nice to meet you. Is there something I can help you with or would you like to chat?

    ./after run llama3 'hi'  0.01s user 0.01s system 3% cpu 0.479 total
    ; time ./after run llama3 'hi'
    Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat?

    ./after run llama3 'hi'  0.02s user 0.01s system 5% cpu 0.507 total
    ; time ./after run llama3 'hi'
    Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat?

    ./after run llama3 'hi'  0.02s user 0.01s system 5% cpu 0.507 total
2024-06-24 20:14:03 -07:00
Daniel Hiltgen
ccef9431c8 Merge pull request #5205 from dhiltgen/modelfile_use_mmap
Fix use_mmap parsing for modelfiles
2024-06-21 16:30:36 -07:00
Daniel Hiltgen
642cee1342 Sort the ps output
Provide consistent ordering for the ps command - longest duration listed first
2024-06-21 15:59:41 -07:00
royjhan
9a9e7d83c4 Docs (#5149) 2024-06-21 15:52:09 -07:00
Daniel Hiltgen
9929751cc8 Disable concurrency for AMD + Windows
Until ROCm v6.2 ships, we won't be able to get accurate free memory
reporting on windows, which makes automatic concurrency too risky.
Users can still opt-in but will need to pay attention to model sizes otherwise they may thrash/page VRAM or cause OOM crashes.
All other platforms and GPUs have accurate VRAM reporting wired
up now, so we can turn on concurrency by default.
2024-06-21 15:45:05 -07:00
Daniel Hiltgen
17b7186cd7 Enable concurrency by default
This adjusts our default settings to enable multiple models and parallel
requests to a single model.  Users can still override these by the same
env var settings as before.  Parallel has a direct impact on
num_ctx, which in turn can have a significant impact on small VRAM GPUs
so this change also refines the algorithm so that when parallel is not
explicitly set by the user, we try to find a reasonable default that fits
the model on their GPU(s).  As before, multiple models will only load
concurrently if they fully fit in VRAM.
2024-06-21 15:45:05 -07:00
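
A simplified sketch of the default-selection idea described above (assumed logic and names, not the actual ollama algorithm): when the parallel setting is unset, pick the largest parallel count whose scaled context still fits in free VRAM.

```go
// Hypothetical auto-selection of a parallel default.
package scheduler

func pickParallel(userSetting int, perRequestBytes, freeVRAM uint64) int {
	if userSetting > 0 {
		return userSetting // an explicit env var setting always wins
	}
	// Parallel requests multiply the context allocation, so walk down from
	// the preferred default of 4 until the contexts fit.
	for parallel := 4; parallel > 1; parallel-- {
		if uint64(parallel)*perRequestBytes <= freeVRAM {
			return parallel
		}
	}
	return 1
}
```
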
Michael Yang
189a43caa2 Merge pull request #5206 from ollama/mxyng/quantize
fix: quantization with template
2024-06-21 13:44:34 -07:00
Michael Yang
e835ef1836 fix: quantization with template 2024-06-21 13:39:25 -07:00
Daniel Hiltgen
7e7749224c Fix use_mmap parsing for modelfiles
Add the new tristate parsing logic for the code path for modelfiles,
as well as a unit test.
2024-06-21 12:27:19 -07:00
Daniel Hiltgen
c7c2f3bc22 Merge pull request #5194 from dhiltgen/linux_mmap_auto
Refine mmap default logic on linux
2024-06-20 11:44:08 -07:00
Daniel Hiltgen
54a79d6a8a Merge pull request #5125 from dhiltgen/fedora39
Bump latest fedora cuda repo to 39
2024-06-20 11:27:24 -07:00
Daniel Hiltgen
5bf5aeec01 Refine mmap default logic on linux
If we try to use mmap when the model is larger than the system free space, loading is slower than the no-mmap approach.
2024-06-20 11:07:04 -07:00
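
A sketch of the mmap heuristic described above, with assumed names rather than the repo's actual logic. The idea: mmap-ing a model larger than free memory forces the page cache to thrash, so fall back to a plain read in that case.

```go
// Hypothetical mmap default decision.
package llm

func useMMapDefault(requested *bool, modelSize, freeMemory uint64) bool {
	if requested != nil {
		return *requested // an explicit use_mmap option always wins
	}
	return modelSize <= freeMemory
}
```
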
Daniel Hiltgen
1a1c99e334 Bump latest fedora cuda repo to 39 2024-06-18 17:13:54 -07:00
173 changed files with 4213 additions and 1661 deletions

View File

@@ -58,6 +58,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     env:
       GOARCH: ${{ matrix.arch }}
+      CGO_ENABLED: '1'
     steps:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
@@ -79,6 +80,7 @@ jobs:
       - run: go generate -x ./...
         if: ${{ ! startsWith(matrix.os, 'windows-') }}
         name: 'Unix Go Generate'
+      - run: go build .
       - uses: actions/upload-artifact@v4
         with:
           name: ${{ matrix.os }}-${{ matrix.arch }}-libraries

View File

@@ -70,12 +70,12 @@ RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
 RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
-FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
+FROM --platform=linux/arm64 rockylinux:8 AS cpu-builder-arm64
 ARG CMAKE_VERSION
 ARG GOLANG_VERSION
 COPY ./scripts/rh_linux_deps.sh /
 RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
-ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
+ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
 COPY --from=llm-code / /go/src/github.com/ollama/ollama/
 ARG OLLAMA_CUSTOM_CPU_DEFS
 ARG CGO_CFLAGS

View File

@@ -53,8 +53,8 @@ Here are some example models that can be downloaded:
 | Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
 | Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
 | Phi 3 Medium | 14B | 7.9GB | `ollama run phi3:medium` |
-| Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
-| Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
+| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
+| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
 | Mistral | 7B | 4.1GB | `ollama run mistral` |
 | Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
 | Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
@@ -182,6 +182,12 @@ $ ollama run llama3 "Summarize this file: $(cat README.md)"
 Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
 ```
 
+### Show model information
+
+```
+ollama show llama3
+```
+
 ### List models on your computer
 
 ```
@@ -286,6 +292,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
 - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
+- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
 
 ### Terminal

View File

@@ -17,14 +17,20 @@ import (
 	"bufio"
 	"bytes"
 	"context"
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
 	"net"
 	"net/http"
 	"net/url"
+	"os"
+	"path/filepath"
 	"runtime"
+	"strings"
+	"time"
 
+	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/version"
@@ -374,3 +380,27 @@ func (c *Client) Version(ctx context.Context) (string, error) {
 	return version.Version, nil
 }
+
+func Authorization(ctx context.Context, request *http.Request) (string, error) {
+	data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", err
+	}
+
+	knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
+	if err != nil {
+		return "", err
+	}
+	defer knownHostsFile.Close()
+
+	token, err := auth.Sign(ctx, data)
+	if err != nil {
+		return "", err
+	}
+
+	// interleave request data into the token
+	key, sig, _ := strings.Cut(token, ":")
+	return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
+}

View File

@@ -159,49 +159,18 @@ type Options struct {
 // Runner options which must be set when the model is loaded into memory
 type Runner struct {
 	UseNUMA   bool `json:"numa,omitempty"`
 	NumCtx    int  `json:"num_ctx,omitempty"`
 	NumBatch  int  `json:"num_batch,omitempty"`
 	NumGPU    int  `json:"num_gpu,omitempty"`
 	MainGPU   int  `json:"main_gpu,omitempty"`
 	LowVRAM   bool `json:"low_vram,omitempty"`
 	F16KV     bool `json:"f16_kv,omitempty"`
 	LogitsAll bool `json:"logits_all,omitempty"`
 	VocabOnly bool `json:"vocab_only,omitempty"`
-	UseMMap   TriState `json:"use_mmap,omitempty"`
+	UseMMap   *bool `json:"use_mmap,omitempty"`
 	UseMLock  bool `json:"use_mlock,omitempty"`
 	NumThread int  `json:"num_thread,omitempty"`
 }
-
-type TriState int
-
-const (
-	TriStateUndefined TriState = -1
-	TriStateFalse     TriState = 0
-	TriStateTrue      TriState = 1
-)
-
-func (b *TriState) UnmarshalJSON(data []byte) error {
-	var v bool
-	if err := json.Unmarshal(data, &v); err != nil {
-		return err
-	}
-	if v {
-		*b = TriStateTrue
-	}
-	*b = TriStateFalse
-	return nil
-}
-
-func (b *TriState) MarshalJSON() ([]byte, error) {
-	if *b == TriStateUndefined {
-		return nil, nil
-	}
-	var v bool
-	if *b == TriStateTrue {
-		v = true
-	}
-	return json.Marshal(v)
-}
 
 // EmbeddingRequest is the request passed to [Client.Embeddings].
@@ -345,6 +314,13 @@ type ProcessModelResponse struct {
 	SizeVRAM int64 `json:"size_vram"`
 }
 
+type RetrieveModelResponse struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
 type TokenResponse struct {
 	Token string `json:"token"`
 }
@@ -437,19 +413,6 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 			continue
 		}
 
-		if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
-			val, ok := val.(bool)
-			if !ok {
-				return fmt.Errorf("option %q must be of type boolean", key)
-			}
-			if val {
-				field.SetInt(int64(TriStateTrue))
-			} else {
-				field.SetInt(int64(TriStateFalse))
-			}
-			continue
-		}
-
 		switch field.Kind() {
 		case reflect.Int:
 			switch t := val.(type) {
@@ -496,6 +459,17 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 					slice[i] = str
 				}
 				field.Set(reflect.ValueOf(slice))
+			case reflect.Pointer:
+				var b bool
+				if field.Type() == reflect.TypeOf(&b) {
+					val, ok := val.(bool)
+					if !ok {
+						return fmt.Errorf("option %q must be of type boolean", key)
+					}
+					field.Set(reflect.ValueOf(&val))
+				} else {
+					return fmt.Errorf("unknown type loading config params: %v %v", field.Kind(), field.Type())
+				}
 			default:
 				return fmt.Errorf("unknown type loading config params: %v", field.Kind())
 			}
@@ -538,7 +512,7 @@ func DefaultOptions() Options {
 			LowVRAM:  false,
 			F16KV:    true,
 			UseMLock: false,
-			UseMMap:  TriStateUndefined,
+			UseMMap:  nil,
 			UseNUMA:  false,
 		},
 	}
@@ -635,6 +609,17 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
 		case reflect.Slice:
 			// TODO: only string slices are supported right now
 			out[key] = vals
+		case reflect.Pointer:
+			var b bool
+			if field.Type() == reflect.TypeOf(&b) {
+				boolVal, err := strconv.ParseBool(vals[0])
+				if err != nil {
+					return nil, fmt.Errorf("invalid bool value %s", vals)
+				}
+				out[key] = &boolVal
+			} else {
+				return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
+			}
 		default:
 			return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
 		}

View File

@@ -2,6 +2,7 @@ package api
 
 import (
 	"encoding/json"
+	"fmt"
 	"math"
 	"testing"
 	"time"
@@ -107,25 +108,27 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
 }
 
 func TestUseMmapParsingFromJSON(t *testing.T) {
+	tr := true
+	fa := false
 	tests := []struct {
 		name string
 		req  string
-		exp  TriState
+		exp  *bool
 	}{
 		{
 			name: "Undefined",
 			req:  `{ }`,
-			exp:  TriStateUndefined,
+			exp:  nil,
 		},
 		{
 			name: "True",
 			req:  `{ "use_mmap": true }`,
-			exp:  TriStateTrue,
+			exp:  &tr,
 		},
 		{
 			name: "False",
 			req:  `{ "use_mmap": false }`,
-			exp:  TriStateFalse,
+			exp:  &fa,
 		},
 	}
@@ -141,3 +144,67 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
 		})
 	}
 }
+
+func TestUseMmapFormatParams(t *testing.T) {
+	tr := true
+	fa := false
+	tests := []struct {
+		name string
+		req  map[string][]string
+		exp  *bool
+		err  error
+	}{
+		{
+			name: "True",
+			req: map[string][]string{
+				"use_mmap": {"true"},
+			},
+			exp: &tr,
+			err: nil,
+		},
+		{
+			name: "False",
+			req: map[string][]string{
+				"use_mmap": {"false"},
+			},
+			exp: &fa,
+			err: nil,
+		},
+		{
+			name: "Numeric True",
+			req: map[string][]string{
+				"use_mmap": {"1"},
+			},
+			exp: &tr,
+			err: nil,
+		},
+		{
+			name: "Numeric False",
+			req: map[string][]string{
+				"use_mmap": {"0"},
+			},
+			exp: &fa,
+			err: nil,
+		},
+		{
+			name: "invalid string",
+			req: map[string][]string{
+				"use_mmap": {"foo"},
+			},
+			exp: nil,
+			err: fmt.Errorf("invalid bool value [foo]"),
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			resp, err := FormatParams(test.req)
+			require.Equal(t, test.err, err)
+			respVal, ok := resp["use_mmap"]
+			if test.exp != nil {
+				assert.True(t, ok, "resp: %v", resp)
+				assert.Equal(t, *test.exp, *respVal.(*bool))
+			}
+		})
+	}
+}

View File

@@ -10,42 +10,37 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
-	"strings"
 
 	"golang.org/x/crypto/ssh"
 )
 
 const defaultPrivateKey = "id_ed25519"
 
-func keyPath() (string, error) {
+func keyPath() (ssh.Signer, error) {
 	home, err := os.UserHomeDir()
 	if err != nil {
-		return "", err
-	}
-
-	return filepath.Join(home, ".ollama", defaultPrivateKey), nil
-}
-
-func GetPublicKey() (string, error) {
-	keyPath, err := keyPath()
-	if err != nil {
-		return "", err
+		return nil, err
 	}
 
+	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
 	privateKeyFile, err := os.ReadFile(keyPath)
 	if err != nil {
 		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
-		return "", err
+		return nil, err
 	}
 
-	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
+	return ssh.ParsePrivateKey(privateKeyFile)
+}
+
+func GetPublicKey() (ssh.PublicKey, error) {
+	privateKey, err := keyPath()
+	// if privateKey, try public key directly
 	if err != nil {
-		return "", err
+		return nil, err
 	}
 
-	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
-
-	return strings.TrimSpace(string(publicKey)), nil
+	return privateKey.PublicKey(), nil
 }
 
 func NewNonce(r io.Reader, length int) (string, error) {
@@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
 }
 
 func Sign(ctx context.Context, bts []byte) (string, error) {
-	keyPath, err := keyPath()
-	if err != nil {
-		return "", err
-	}
-
-	privateKeyFile, err := os.ReadFile(keyPath)
-	if err != nil {
-		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
-		return "", err
-	}
-
-	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
+	privateKey, err := keyPath()
 	if err != nil {
 		return "", err
 	}
 
 	// get the pubkey, but remove the type
-	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
-	parts := bytes.Split(publicKey, []byte(" "))
+	publicKey, err := GetPublicKey()
+	if err != nil {
+		return "", err
+	}
+
+	publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
+	parts := bytes.Split(publicKeyBytes, []byte(" "))
 	if len(parts) < 2 {
 		return "", fmt.Errorf("malformed public key")
 	}

View File

@@ -7,6 +7,7 @@ import (
 	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/sha256"
+	"encoding/json"
 	"encoding/pem"
 	"errors"
 	"fmt"
@@ -15,6 +16,7 @@ import (
 	"math"
 	"net"
 	"net/http"
+	"net/url"
 	"os"
 	"os/signal"
 	"path/filepath"
@@ -78,6 +80,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	status := "transferring model data"
 	spinner := progress.NewSpinner(status)
 	p.Add(status, spinner)
+	defer p.Stop()
 
 	for i := range modelfile.Commands {
 		switch modelfile.Commands[i].Name {
@@ -112,11 +115,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				path = tempfile
 			}
 
-			digest, err := createBlob(cmd, client, path)
+			digest, err := createBlob(cmd, client, path, spinner)
 			if err != nil {
 				return err
 			}
-
 			modelfile.Commands[i].Args = "@" + digest
 		}
 	}
@@ -138,7 +140,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			spinner.Stop()
 
 			status = resp.Status
-			spinner = progress.NewSpinner(status)
+			spinner := progress.NewSpinner(status)
 			p.Add(status, spinner)
 		}
@@ -162,9 +164,6 @@ func tempZipFiles(path string) (string, error) {
 	}
 	defer tempfile.Close()
 
-	zipfile := zip.NewWriter(tempfile)
-	defer zipfile.Close()
-
 	detectContentType := func(path string) (string, error) {
 		f, err := os.Open(path)
 		if err != nil {
@@ -233,6 +232,9 @@ func tempZipFiles(path string) (string, error) {
 		files = append(files, tks...)
 	}
 
+	zipfile := zip.NewWriter(tempfile)
+	defer zipfile.Close()
+
 	for _, file := range files {
 		f, err := os.Open(file)
 		if err != nil {
@@ -263,13 +265,22 @@ func tempZipFiles(path string) (string, error) {
 	return tempfile.Name(), nil
 }
 
-func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
+var ErrBlobExists = errors.New("blob exists")
+
+func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
 	bin, err := os.Open(path)
 	if err != nil {
 		return "", err
 	}
 	defer bin.Close()
 
+	// Get file info to retrieve the size
+	fileInfo, err := bin.Stat()
+	if err != nil {
+		return "", err
+	}
+	fileSize := fileInfo.Size()
+
 	hash := sha256.New()
 	if _, err := io.Copy(hash, bin); err != nil {
 		return "", err
@@ -279,46 +290,158 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 		return "", err
 	}
 
+	var pw progressWriter
+	status := "transferring model data 0%"
+	spinner.SetMessage(status)
+
+	ticker := time.NewTicker(60 * time.Millisecond)
+	done := make(chan struct{})
+	defer close(done)
+
+	go func() {
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n/fileSize)))
+			case <-done:
+				spinner.SetMessage("transferring model data 100%")
+				return
+			}
+		}
+	}()
+
 	digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
-	if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
+
+	// We check if we can find the models directory locally
+	// If we can, we return the path to the directory
+	// If we can't, we return an error
+	// If the blob exists already, we return the digest
+	dest, err := getLocalPath(cmd.Context(), digest)
+	if errors.Is(err, ErrBlobExists) {
+		return digest, nil
+	}
+
+	// Successfully found the model directory
+	if err == nil {
+		// Copy blob in via OS specific copy
+		// Linux errors out to use io.copy
+		err = localCopy(path, dest)
+		if err == nil {
+			return digest, nil
+		}
+
+		// Default copy using io.copy
+		err = defaultCopy(path, dest)
+		if err == nil {
+			return digest, nil
+		}
+	}
+
+	// If at any point copying the blob over locally fails, we default to the copy through the server
+	if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
 		return "", err
 	}
 
 	return digest, nil
 }
 
-func RunHandler(cmd *cobra.Command, args []string) error {
-	client, err := api.ClientFromEnvironment()
-	if err != nil {
-		return err
-	}
-	name := args[0]
-
-	// check if the model exists on the server
-	show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
-	var statusError api.StatusError
-	switch {
-	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
-		if err := PullHandler(cmd, []string{name}); err != nil {
-			return err
-		}
-
-		show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
-		if err != nil {
-			return err
-		}
-	case err != nil:
-		return err
-	}
+type progressWriter struct {
+	n int64
+}
+
+func (w *progressWriter) Write(p []byte) (n int, err error) {
+	w.n += int64(len(p))
+	return len(p), nil
+}
+
+func getLocalPath(ctx context.Context, digest string) (string, error) {
+	ollamaHost := envconfig.Host
+
+	client := http.DefaultClient
+	base := &url.URL{
+		Scheme: ollamaHost.Scheme,
+		Host:   net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
+	}
+
+	data, err := json.Marshal(digest)
+	if err != nil {
+		return "", err
+	}
+	reqBody := bytes.NewReader(data)
+	path := fmt.Sprintf("/api/blobs/%s", digest)
+	requestURL := base.JoinPath(path)
+	request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
+	if err != nil {
+		return "", err
+	}
+
+	authz, err := api.Authorization(ctx, request)
+	if err != nil {
+		return "", err
+	}
+
+	request.Header.Set("Authorization", authz)
+	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
+	request.Header.Set("X-Redirect-Create", "1")
+
+	resp, err := client.Do(request)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusTemporaryRedirect {
+		dest := resp.Header.Get("LocalLocation")
+		return dest, nil
+	}
+
+	return "", ErrBlobExists
+}
+
+func defaultCopy(path string, dest string) error {
+	// This function should be called if the server is local
+	// It should find the model directory, copy the blob over, and return the digest
+	dirPath := filepath.Dir(dest)
+	if err := os.MkdirAll(dirPath, 0o755); err != nil {
+		return err
+	}
+
+	// Copy blob over
+	sourceFile, err := os.Open(path)
+	if err != nil {
+		return fmt.Errorf("could not open source file: %v", err)
+	}
+	defer sourceFile.Close()
+
+	destFile, err := os.Create(dest)
+	if err != nil {
+		return fmt.Errorf("could not create destination file: %v", err)
+	}
+	defer destFile.Close()
+
+	_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
+	if err != nil {
+		return fmt.Errorf("error copying file: %v", err)
+	}
+
+	err = destFile.Sync()
+	if err != nil {
+		return fmt.Errorf("error flushing file: %v", err)
+	}
+
+	return nil
+}
+
+func RunHandler(cmd *cobra.Command, args []string) error {
 	interactive := true
 
 	opts := runOptions{
-		Model:       args[0],
-		WordWrap:    os.Getenv("TERM") == "xterm-256color",
-		Options:     map[string]interface{}{},
-		MultiModal:  slices.Contains(show.Details.Families, "clip"),
-		ParentModel: show.Details.ParentModel,
+		Model:    args[0],
+		WordWrap: os.Getenv("TERM") == "xterm-256color",
+		Options:  map[string]interface{}{},
 	}
 
 	format, err := cmd.Flags().GetString("format")
@@ -362,11 +485,38 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	}
 	opts.WordWrap = !nowrap
 
-	if !interactive {
-		return generate(cmd, opts)
+	// Fill out the rest of the options based on information about the
+	// model.
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
 	}
 
-	return generateInteractive(cmd, opts)
+	name := args[0]
+	info, err := func() (*api.ShowResponse, error) {
+		showReq := &api.ShowRequest{Name: name}
+		info, err := client.Show(cmd.Context(), showReq)
+		var se api.StatusError
+		if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
+			if err := PullHandler(cmd, []string{name}); err != nil {
+				return nil, err
+			}
+			return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
+		}
+		return info, err
+	}()
+	if err != nil {
+		return err
+	}
+
+	opts.MultiModal = slices.Contains(info.Details.Families, "clip")
+	opts.ParentModel = info.Details.ParentModel
+	opts.Messages = append(opts.Messages, info.Messages...)
+
+	if interactive {
+		return generateInteractive(cmd, opts)
+	}
+	return generate(cmd, opts)
 }
 
 func errFromUnknownKey(unknownKeyErr error) error {
@@ -378,11 +528,13 @@ func errFromUnknownKey(unknownKeyErr error) error {
 	if len(matches) > 0 {
 		serverPubKey := matches[0]
 
-		localPubKey, err := auth.GetPublicKey()
+		publicKey, err := auth.GetPublicKey()
 		if err != nil {
 			return unknownKeyErr
 		}
 
+		localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
+
 		if runtime.GOOS == "linux" && serverPubKey != localPubKey {
 			// try the ollama service public key
 			svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
@@ -623,13 +775,13 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
 	}
 
-	if flagsSet == 1 {
-		req := api.ShowRequest{Name: args[0]}
-		resp, err := client.Show(cmd.Context(), &req)
-		if err != nil {
-			return err
-		}
+	req := api.ShowRequest{Name: args[0]}
+	resp, err := client.Show(cmd.Context(), &req)
+	if err != nil {
+		return err
+	}
 
+	if flagsSet == 1 {
 		switch showType {
 		case "license":
 			fmt.Println(resp.License)
@@ -646,12 +798,12 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return nil
 	}
 
-	req := api.ShowRequest{Name: args[0]}
-	resp, err := client.Show(cmd.Context(), &req)
-	if err != nil {
-		return err
-	}
+	showInfo(resp)
 
+	return nil
+}
+
+func showInfo(resp *api.ShowResponse) {
 	arch := resp.ModelInfo["general.architecture"].(string)
 
 	modelData := [][]string{
@@ -671,11 +823,17 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 	projectorData := [][]string{
 		{"arch", "clip"},
 		{"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))},
-		{"projector type", resp.ProjectorInfo["clip.projector_type"].(string)},
-		{"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))},
-		{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))},
 	}
 
+	if projectorType, ok := resp.ProjectorInfo["clip.projector_type"]; ok {
+		projectorData = append(projectorData, []string{"projector type", projectorType.(string)})
+	}
+
+	projectorData = append(projectorData,
+		[]string{"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))},
+		[]string{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))},
+	)
+
 	mainTableData = append(mainTableData,
 		[]string{"Projector"},
 		[]string{renderSubTable(projectorData, false)},
@@ -704,8 +862,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 	}
 
 	table.Render()
-
-	return nil
 }
 
 func renderSubTable(data [][]string, file bool) string {

cmd/copy_darwin.go (new file, 23 lines)

View File

@@ -0,0 +1,23 @@
+package cmd
+
+import (
+	"os"
+	"path/filepath"
+
+	"golang.org/x/sys/unix"
+)
+
+func localCopy(src, target string) error {
+	dirPath := filepath.Dir(target)
+	if err := os.MkdirAll(dirPath, 0o755); err != nil {
+		return err
+	}
+
+	err := unix.Clonefile(src, target, 0)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}

cmd/copy_linux.go (new file, 7 lines)

View File

@@ -0,0 +1,7 @@
+package cmd
+
+import "errors"
+
+func localCopy(src, target string) error {
+	return errors.New("no local copy implementation for linux")
+}

cmd/copy_windows.go (new file, 67 lines)

View File

@@ -0,0 +1,67 @@
+//go:build windows
+// +build windows
+
+package cmd
+
+import (
+	"os"
+	"path/filepath"
+	"syscall"
+	"unsafe"
+)
+
+func localCopy(src, target string) error {
+	// Create target directory if it doesn't exist
+	dirPath := filepath.Dir(target)
+	if err := os.MkdirAll(dirPath, 0o755); err != nil {
+		return err
+	}
+
+	// Open source file
+	sourceFile, err := os.Open(src)
+	if err != nil {
+		return err
+	}
+	defer sourceFile.Close()
+
+	// Create target file
+	targetFile, err := os.Create(target)
+	if err != nil {
+		return err
+	}
+	defer targetFile.Close()
+
+	// Use CopyFileExW to copy the file
+	err = copyFileEx(src, target)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func copyFileEx(src, dst string) error {
+	kernel32 := syscall.NewLazyDLL("kernel32.dll")
+	copyFileEx := kernel32.NewProc("CopyFileExW")
+
+	srcPtr, err := syscall.UTF16PtrFromString(src)
+	if err != nil {
+		return err
+	}
+
+	dstPtr, err := syscall.UTF16PtrFromString(dst)
+	if err != nil {
+		return err
+	}
+
+	r1, _, err := copyFileEx.Call(
+		uintptr(unsafe.Pointer(srcPtr)),
+		uintptr(unsafe.Pointer(dstPtr)),
+		0, 0, 0, 0)
+	if r1 == 0 {
+		return err
+	}
+
+	return nil
+}

View File

@@ -31,65 +31,40 @@ const (
 )
 
 func loadModel(cmd *cobra.Command, opts *runOptions) error {
-	client, err := api.ClientFromEnvironment()
-	if err != nil {
-		return err
-	}
-
 	p := progress.NewProgress(os.Stderr)
 	defer p.StopAndClear()
 
 	spinner := progress.NewSpinner("")
 	p.Add("", spinner)
 
-	showReq := api.ShowRequest{Name: opts.Model}
-	showResp, err := client.Show(cmd.Context(), &showReq)
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
-	opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
-	opts.ParentModel = showResp.Details.ParentModel
-
-	if len(showResp.Messages) > 0 {
-		opts.Messages = append(opts.Messages, showResp.Messages...)
-	}
 
 	chatReq := &api.ChatRequest{
-		Model:    opts.Model,
-		Messages: []api.Message{},
+		Model:     opts.Model,
+		KeepAlive: opts.KeepAlive,
 	}
 
-	if opts.KeepAlive != nil {
-		chatReq.KeepAlive = opts.KeepAlive
-	}
-
-	err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
+	return client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
 		p.StopAndClear()
-		if len(opts.Messages) > 0 {
-			for _, msg := range opts.Messages {
-				switch msg.Role {
-				case "user":
-					fmt.Printf(">>> %s\n", msg.Content)
-				case "assistant":
-					state := &displayResponseState{}
-					displayResponse(msg.Content, opts.WordWrap, state)
-					fmt.Println()
-					fmt.Println()
-				}
+		for _, msg := range opts.Messages {
+			switch msg.Role {
+			case "user":
+				fmt.Printf(">>> %s\n", msg.Content)
+			case "assistant":
+				state := &displayResponseState{}
+				displayResponse(msg.Content, opts.WordWrap, state)
+				fmt.Println()
+				fmt.Println()
 			}
 		}
 		return nil
 	})
-	if err != nil {
-		return err
-	}
-
-	return nil
 }
 
 func generateInteractive(cmd *cobra.Command, opts runOptions) error {
-	opts.Messages = make([]api.Message, 0)
-
 	err := loadModel(cmd, &opts)
 	if err != nil {
 		return err
@@ -429,15 +404,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			switch args[1] {
 			case "info":
-				fmt.Println("Model details:")
-				if len(resp.Details.Families) > 0 {
-					fmt.Printf("Family              %s\n", strings.Join(resp.Details.Families, ", "))
-				} else if resp.Details.Family != "" {
-					fmt.Printf("Family              %s\n", resp.Details.Family)
-				}
-				fmt.Printf("Parameter Size      %s\n", resp.Details.ParameterSize)
-				fmt.Printf("Quantization Level  %s\n", resp.Details.QuantizationLevel)
-				fmt.Println("")
+				showInfo(resp)
 			case "license":
 				if resp.License == "" {
 					fmt.Println("No license was specified for this model.")

View File

@@ -26,7 +26,7 @@ All durations are returned in nanoseconds.
 
 ### Streaming responses
 
-Certain endpoints stream responses as JSON objects and can optional return non-streamed responses.
+Certain endpoints stream responses as JSON objects. Streaming can be disabled by providing `{"stream": false}` for these endpoints.
 
 ## Generate a completion

View File

@@ -104,7 +104,7 @@ like to use. For example, to compile an optimized binary for an Intel i9-9880H,
 you might use:
 
 ```
-OLLAMA_CUSTOM_CPU_DEFS="-DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_F16C=on -DLLAMA_FMA=on" go generate ./...
+OLLAMA_CUSTOM_CPU_DEFS="-DGGML_AVX=on -DGGML_AVX2=on -DGGML_F16C=on -DGGML_FMA=on" go generate ./...
 go build .
 ```

View File

@@ -257,3 +257,19 @@ If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` AP
## How do I manage the maximum number of requests the Ollama server can queue? ## How do I manage the maximum number of requests the Ollama server can queue?
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`. If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`.
## How does Ollama handle concurrent requests?
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing.
If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded. As prior models become idle, one or more will be unloaded to make room for the new model. Queued requests will be processed in order. When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads.
Parallel request processing for a given model results in increasing the context size by the number of parallel requests. For example, a 2K context with 4 parallel requests will result in an 8K context and additional memory allocation.
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs' VRAM.
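A toy sketch of the context arithmetic described above (the values are illustrative, not defaults):

```go
package main

import "fmt"

func main() {
	// Each parallel slot gets its own slice of the context window, so the
	// allocated context grows linearly with OLLAMA_NUM_PARALLEL.
	numCtx := 2048   // per-request context size
	numParallel := 4 // OLLAMA_NUM_PARALLEL
	fmt.Printf("allocated context: %d tokens\n", numCtx*numParallel) // 8192
}
```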


@@ -18,7 +18,7 @@ Check your compute compatibility to see if your card is supported:
| | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` |
| 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` |
| 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` |
-| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050` |
+| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` |
| | Quadro | `P6000` `P5200` `P4200` `P3200` `P5000` `P4000` `P3000` `P2200` `P2000` `P1000` `P620` `P600` `P500` `P520` |
| | Tesla | `P40` `P4` |
| 6.0 | NVIDIA | `Tesla P100` `Quadro GP100` |


@@ -65,6 +65,7 @@ curl http://localhost:11434/v1/chat/completions \
}
]
}'
```
## Endpoints
@@ -104,7 +105,6 @@ curl http://localhost:11434/v1/chat/completions \
#### Notes
-- `finish_reason` will always be `stop`
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
## Models


@@ -70,14 +70,18 @@ curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh
If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
-## Container fails to run on NVIDIA GPU
-Make sure you've set up the container runtime first as described in [docker.md](./docker.md)
-Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
-- Is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama won't be able to see your NVIDIA GPU.
-- Is the uvm driver not loaded? `sudo nvidia-modprobe -u`
+## NVIDIA GPU Discovery
+When Ollama starts up, it takes inventory of the GPUs present in the system to determine compatibility and how much VRAM is available. Sometimes this discovery can fail to find your GPUs. In general, running the latest driver will yield the best results.
+### Linux NVIDIA Troubleshooting
+If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
+Sometimes Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
+- If you are using a container, is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama won't be able to see your NVIDIA GPU.
+- Is the uvm driver loaded? `sudo nvidia-modprobe -u`
- Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
- Try rebooting
- Make sure you're running the latest nvidia drivers
@@ -85,3 +89,8 @@ Sometimes the container runtime can have difficulties initializing the GPU. When
If none of those resolve the problem, gather additional information and file an issue:
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
## Windows Terminal Errors
Older versions of Windows 10 (e.g., 21H1) are known to have a bug where the standard terminal program does not display control characters correctly. This can result in a long string of control sequences like `←[?25h←[?25l` being displayed (these are ANSI cursor-visibility escape codes the terminal fails to interpret), sometimes erroring with `The parameter is incorrect`. To resolve this problem, please update to Win 10 22H1 or newer.


@@ -19,7 +19,7 @@ Logs will often be helpful in diagnosing the problem (see
## System Requirements
-* Windows 10 or newer, Home or Pro
+* Windows 10 22H2 or newer, Home or Pro
* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card


@@ -4,12 +4,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time"
) )
type OllamaHost struct { type OllamaHost struct {
@@ -34,17 +36,17 @@ var (
// Set via OLLAMA_HOST in the environment
Host *OllamaHost
// Set via OLLAMA_KEEP_ALIVE in the environment
-KeepAlive string
+KeepAlive time.Duration
// Set via OLLAMA_LLM_LIBRARY in the environment
LLMLibrary string
// Set via OLLAMA_MAX_LOADED_MODELS in the environment
MaxRunners int
// Set via OLLAMA_MAX_QUEUE in the environment
MaxQueuedRequests int
-// Set via OLLAMA_MODELS in the environment
-ModelsDir string
// Set via OLLAMA_MAX_VRAM in the environment
MaxVRAM uint64
+// Set via OLLAMA_MODELS in the environment
+ModelsDir string
// Set via OLLAMA_NOHISTORY in the environment
NoHistory bool
// Set via OLLAMA_NOPRUNE in the environment
@@ -85,13 +87,13 @@ func AsMap() map[string]EnvVar {
"OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"}, "OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"}, "OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
@@ -129,9 +131,10 @@ func clean(key string) string {
func init() {
// default values
-NumParallel = 1
-MaxRunners = 1
+NumParallel = 0 // Autoselect
+MaxRunners = 0  // Autoselect
MaxQueuedRequests = 512
KeepAlive = 5 * time.Minute
LoadConfig()
}
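Zero now acts as an "autoselect" sentinel rather than a hard default. A rough sketch of the convention; the real selection happens later in the scheduler, and the memory predicate here is an assumed placeholder, not part of this patch:

```go
package scheduler // illustrative placement, not the real package

// pickNumParallel resolves the OLLAMA_NUM_PARALLEL sentinel: 0 means
// "autoselect", which the FAQ above describes as 4 or 1 depending on
// available memory. enoughFreeMemory is a hypothetical stand-in.
func pickNumParallel(configured int, enoughFreeMemory bool) int {
	if configured != 0 {
		return configured
	}
	if enoughFreeMemory {
		return 4
	}
	return 1
}
```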
@@ -205,8 +208,8 @@ func LoadConfig() {
if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
val, err := strconv.Atoi(onp)
-if err != nil || val <= 0 {
-slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
+if err != nil {
+slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err)
} else {
NumParallel = val
}
@@ -251,7 +254,7 @@ func LoadConfig() {
if maxRunners != "" { if maxRunners != "" {
m, err := strconv.Atoi(maxRunners) m, err := strconv.Atoi(maxRunners)
if err != nil { if err != nil {
slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err) slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
} else { } else {
MaxRunners = m MaxRunners = m
} }
@@ -260,13 +263,16 @@ func LoadConfig() {
if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
p, err := strconv.Atoi(onp)
if err != nil || p <= 0 {
-slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err)
+slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err)
} else {
MaxQueuedRequests = p
}
}
-KeepAlive = clean("OLLAMA_KEEP_ALIVE")
+ka := clean("OLLAMA_KEEP_ALIVE")
+if ka != "" {
+loadKeepAlive(ka)
+}
var err error
ModelsDir, err = getModelsDir()
@@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
Port: port,
}, nil
}
func loadKeepAlive(ka string) {
v, err := strconv.Atoi(ka)
if err != nil {
d, err := time.ParseDuration(ka)
if err == nil {
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
} else {
d := time.Duration(v) * time.Second
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
}


@@ -2,8 +2,10 @@ package envconfig
import (
"fmt"
"math"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
t.Setenv("OLLAMA_FLASH_ATTENTION", "1") t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig() LoadConfig()
require.True(t, FlashAttention) require.True(t, FlashAttention)
t.Setenv("OLLAMA_KEEP_ALIVE", "")
LoadConfig()
require.Equal(t, 5*time.Minute, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "3")
LoadConfig()
require.Equal(t, 3*time.Second, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
LoadConfig()
require.Equal(t, 1*time.Hour, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
}
func TestClientFromEnvironment(t *testing.T) {

go.mod

@@ -18,6 +18,7 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
@@ -71,7 +72,7 @@ require (
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0
golang.org/x/term v0.20.0
-golang.org/x/text v0.15.0 // indirect
+golang.org/x/text v0.15.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)


@@ -115,8 +115,6 @@ func AMDGetGPUInfo() []RocmGPUInfo {
continue
}
-// TODO revisit this once ROCm v6 is available on windows.
-// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
gpuInfo := RocmGPUInfo{
@@ -126,6 +124,9 @@ func AMDGetGPUInfo() []RocmGPUInfo {
TotalMemory: totalMemory,
FreeMemory: freeMemory,
},
// Free memory reporting on Windows is not reliable until we bump to ROCm v6.2
UnreliableFreeMemory: true,
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
DependencyPath: libDir,
MinimumMemory: rocmMinimumMemory,


@@ -202,7 +202,7 @@ func GetGPUInfo() GpuInfoList {
}()
if !bootstrapped {
-slog.Debug("Detecting GPUs")
+slog.Info("looking for compatible GPUs")
needRefresh = false
cpuCapability = GetCPUCapability()
var memInfo C.mem_info_t
@@ -320,6 +320,9 @@ func GetGPUInfo() GpuInfoList {
rocmGPUs = AMDGetGPUInfo()
bootstrapped = true
if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 {
slog.Info("no compatible GPUs were discovered")
}
}
// For detected GPUs, load library if not loaded
@@ -514,7 +517,23 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
defer C.free(unsafe.Pointer(lib))
C.nvcuda_init(lib, &resp)
if resp.err != nil {
-slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
+// Decide what log level based on the type of error message to help users understand why
msg := C.GoString(resp.err)
switch resp.cudaErr {
case C.CUDA_ERROR_INSUFFICIENT_DRIVER, C.CUDA_ERROR_SYSTEM_DRIVER_MISMATCH:
slog.Warn("version mismatch between driver and cuda driver library - reboot or upgrade may be required", "library", libPath, "error", msg)
case C.CUDA_ERROR_NO_DEVICE:
slog.Info("no nvidia devices detected", "library", libPath)
case C.CUDA_ERROR_UNKNOWN:
slog.Warn("unknown error initializing cuda driver library", "library", libPath, "error", msg)
slog.Warn("see https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for more information")
default:
if strings.Contains(msg, "wrong ELF class") {
slog.Debug("skipping 32bit library", "library", libPath)
} else {
slog.Info("unable to load cuda driver library", "library", libPath, "error", msg)
}
}
C.free(unsafe.Pointer(resp.err))
} else {
return int(resp.num_devices), &resp.ch, libPath


@@ -56,7 +56,7 @@ func GetCPUInfo() GpuInfoList {
func GetCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()),
-FreeMemory: 0,
+FreeMemory: uint64(C.getFreeMemory()),
}, nil
}


@@ -2,3 +2,4 @@
#include <stdint.h>
uint64_t getRecommendedMaxVRAM();
uint64_t getPhysicalMemory();
uint64_t getFreeMemory();


@@ -1,4 +1,5 @@
-// go:build darwin
+#import <Foundation/Foundation.h>
+#import <mach/mach.h>
#include "gpu_info_darwin.h"
uint64_t getRecommendedMaxVRAM() {
@@ -8,6 +9,27 @@ uint64_t getRecommendedMaxVRAM() {
return result;
}
// getPhysicalMemory returns the total physical memory in bytes
uint64_t getPhysicalMemory() {
-return [[NSProcessInfo processInfo] physicalMemory];
+return [NSProcessInfo processInfo].physicalMemory;
}
// getFreeMemory returns the total free memory in bytes, including inactive
// memory that can be reclaimed by the system.
uint64_t getFreeMemory() {
mach_port_t host_port = mach_host_self();
mach_msg_type_number_t host_size = sizeof(vm_statistics64_data_t) / sizeof(integer_t);
vm_size_t pagesize;
vm_statistics64_data_t vm_stat;
host_page_size(host_port, &pagesize);
if (host_statistics64(host_port, HOST_VM_INFO64, (host_info64_t)&vm_stat, &host_size) != KERN_SUCCESS) {
return 0;
}
uint64_t free_memory = (uint64_t)vm_stat.free_count * pagesize;
free_memory += (uint64_t)vm_stat.speculative_count * pagesize;
free_memory += (uint64_t)vm_stat.inactive_count * pagesize;
return free_memory;
}
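For orientation, the Go side of this diff consumes the new symbol through cgo. A minimal sketch of that wiring, assuming darwin-only build constraints and that the `.m` file above is compiled into the same package (as the `gpu` package does here):

```go
package gpu

/*
#include "gpu_info_darwin.h"
*/
import "C"

// systemFreeMemory reports bytes the OS could hand back to us, counting
// free, speculative, and reclaimable inactive pages per the .m file above.
func systemFreeMemory() uint64 {
	return uint64(C.getFreeMemory())
}
```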


@@ -7,6 +7,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
CUresult ret;
resp->err = NULL;
resp->num_devices = 0;
resp->cudaErr = CUDA_SUCCESS;
const int buflen = 256;
char buf[buflen + 1];
int i;
@@ -38,6 +39,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
nvcuda_lib_path, msg);
free(msg);
resp->err = strdup(buf);
resp->cudaErr = -1;
return;
}
@@ -52,6 +54,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
msg);
free(msg);
resp->err = strdup(buf);
resp->cudaErr = -1;
return;
}
}
@@ -61,12 +64,9 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
-if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
-resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
-return;
-}
-snprintf(buf, buflen, "nvcuda init failure: %d", ret);
+snprintf(buf, buflen, "cuda driver library init failure: %d", ret);
resp->err = strdup(buf);
resp->cudaErr = ret;
return;
}
@@ -91,6 +91,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
resp->ch.handle = NULL;
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
resp->cudaErr = ret;
return;
}
}
@@ -106,13 +107,13 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
if (h.handle == NULL) {
-resp->err = strdup("nvcuda handle isn't initialized");
+resp->err = strdup("cuda driver library handle isn't initialized");
return;
}
ret = (*h.cuDeviceGet)(&device, i);
if (ret != CUDA_SUCCESS) {
-snprintf(buf, buflen, "nvcuda device failed to initialize");
+snprintf(buf, buflen, "cuda driver library device failed to initialize");
resp->err = strdup(buf);
return;
}
@@ -168,14 +169,14 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
// To get memory we have to set (and release) a context
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
if (ret != CUDA_SUCCESS) {
-snprintf(buf, buflen, "nvcuda failed to get device context %d", ret);
+snprintf(buf, buflen, "cuda driver library failed to get device context %d", ret);
resp->err = strdup(buf);
return;
}
ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
if (ret != CUDA_SUCCESS) {
-snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret);
+snprintf(buf, buflen, "cuda driver library device memory info lookup failure %d", ret);
resp->err = strdup(buf);
// Best effort on failure...
(*h.cuCtxDestroy)(ctx);
@@ -193,7 +194,7 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
ret = (*h.cuCtxDestroy)(ctx);
if (ret != CUDA_SUCCESS) {
-LOG(1, "nvcuda failed to release device context %d", ret);
+LOG(1, "cuda driver library failed to release device context %d", ret);
}
}
@@ -206,7 +207,7 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
ret = (*h.cuDeviceGet)(&device, i);
if (ret != CUDA_SUCCESS) {
-LOG(1, "nvcuda device failed to initialize");
+LOG(1, "cuda driver library device failed to initialize");
return;
}
@@ -214,13 +215,13 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
// To get memory we have to set (and release) a context
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
if (ret != CUDA_SUCCESS) {
-LOG(1, "nvcuda failed to get device context %d", ret);
+LOG(1, "cuda driver library failed to get device context %d", ret);
return;
}
ret = (*h.cuMemGetInfo_v2)(free, total);
if (ret != CUDA_SUCCESS) {
-LOG(1, "nvcuda device memory info lookup failure %d", ret);
+LOG(1, "cuda driver library device memory info lookup failure %d", ret);
// Best effort on failure...
(*h.cuCtxDestroy)(ctx);
return;
@@ -228,12 +229,12 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
ret = (*h.cuCtxDestroy)(ctx);
if (ret != CUDA_SUCCESS) {
-LOG(1, "nvcuda failed to release device context %d", ret);
+LOG(1, "cuda driver library failed to release device context %d", ret);
}
}
void nvcuda_release(nvcuda_handle_t h) {
-LOG(h.verbose, "releasing nvcuda library\n");
+LOG(h.verbose, "releasing cuda driver library\n");
UNLOAD_LIBRARY(h.handle);
// TODO and other context release logic?
h.handle = NULL;


@@ -7,9 +7,12 @@
typedef enum cudaError_enum {
CUDA_SUCCESS = 0,
CUDA_ERROR_INVALID_VALUE = 1,
-CUDA_ERROR_MEMORY_ALLOCATION = 2,
+CUDA_ERROR_OUT_OF_MEMORY = 2,
CUDA_ERROR_NOT_INITIALIZED = 3,
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
CUDA_ERROR_NO_DEVICE = 100,
CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803,
CUDA_ERROR_UNKNOWN = 999,
// Other values omitted for now...
} CUresult;
@@ -64,6 +67,7 @@ typedef struct nvcuda_init_resp {
char *err; // If err is non-null handle is invalid
nvcuda_handle_t ch;
int num_devices;
CUresult cudaErr;
} nvcuda_init_resp_t;
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);


@@ -29,6 +29,11 @@ type GpuInfo struct {
// Extra environment variables specific to the GPU as list of [key,value]
EnvWorkarounds [][2]string `json:"envs,omitempty"`
// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
// the FreeMemory is best effort, and may over or under report actual memory usage
// False indicates FreeMemory can generally be trusted on this GPU
UnreliableFreeMemory bool
// GPU information
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
Name string `json:"name"` // user friendly name if available


@@ -1,14 +1,13 @@
set(TARGET ollama_llama_server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE ggml llama common llava ${CMAKE_THREAD_LIBS_INIT})
if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()
target_compile_features(${TARGET} PRIVATE cxx_std_11)


@@ -1382,12 +1382,50 @@ struct llama_server_context
}
}
std::string common_prefix(const std::string& str1, const std::string& str2) {
auto mismatch_pair = std::mismatch(str1.begin(), str1.end(), str2.begin());
return std::string(str1.begin(), mismatch_pair.first);
}
// Find the slot that has the greatest common prefix
server_slot *prefix_slot(const json &prompt) {
if (!prompt.is_string()) {
return nullptr;
}
std::string prompt_str = prompt.get<std::string>();
server_slot *slot = nullptr;
size_t longest = 0;
for (server_slot &s : slots) {
if (s.available() && s.prompt.is_string()) {
std::string s_prompt = s.prompt.get<std::string>();
std::string prefix = common_prefix(s_prompt, prompt_str);
if (prefix.size() > longest) {
slot = &s;
longest = prefix.size();
}
}
}
if (!slot) {
return get_slot(-1);
}
LOG_DEBUG("slot with common prefix found", {{
"slot_id", slot->id,
"characters", longest
}});
return slot;
}
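The slot picker above favors reuse of cached KV state. A rough Go rendering of the same selection rule (the names are mine, not from the patch):

```go
package main

import "fmt"

// commonPrefixLen returns how many leading bytes a and b share.
func commonPrefixLen(a, b string) int {
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	i := 0
	for i < n && a[i] == b[i] {
		i++
	}
	return i
}

// pickSlot mirrors prefix_slot: choose the free slot whose cached prompt
// shares the longest prefix with the incoming prompt, so the most KV-cache
// state can be reused; -1 means fall back to any free slot.
func pickSlot(cached []string, prompt string) int {
	best, bestLen := -1, 0
	for i, c := range cached {
		if l := commonPrefixLen(c, prompt); l > bestLen {
			best, bestLen = i, l
		}
	}
	return best
}

func main() {
	cached := []string{"You are a helpful assistant. Hi", "Translate to French:"}
	fmt.Println(pickSlot(cached, "You are a helpful assistant. What is Go?")) // 0
}
```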
void process_single_task(task_server& task)
{
switch (task.type)
{
case TASK_TYPE_COMPLETION: {
-server_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
+server_slot *slot = prefix_slot(task.data["prompt"]);
if (slot == nullptr)
{
// if no slot is available, we defer this task for processing later
@@ -1654,22 +1692,23 @@ struct llama_server_context
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
{
const int n_left = slot.n_ctx - slot.params.n_keep;
-const int n_block_size = n_left / 2;
-const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+const int n_shift = n_left / 2;
+const int n_erase = slot.n_prompt_tokens - slot.params.n_keep - n_shift;
std::vector<llama_token> new_tokens(
prompt_tokens.begin(),
prompt_tokens.begin() + slot.params.n_keep);
new_tokens.insert(
new_tokens.end(),
-prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+prompt_tokens.begin() + slot.params.n_keep + n_erase,
prompt_tokens.end());
-LOG_VERBOSE("input truncated", {
+LOG_INFO("input truncated", {
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
-{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+{"n_shift", n_shift},
+{"n_erase", n_erase},
});
slot.truncated = true;
prompt_tokens = new_tokens;
@@ -1704,7 +1743,7 @@ struct llama_server_context
slot.n_past -= 1;
}
-slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
+slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
if (slot.ga_n != 1)
{


@@ -18,16 +18,16 @@ sign() {
fi
}
-COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on -DLLAMA_OPENMP=off"
+COMMON_DARWIN_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DGGML_METAL_EMBED_LIBRARY=on -DGGML_OPENMP=off"
case "${GOARCH}" in
"amd64")
-COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_NATIVE=off"
+COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DGGML_METAL=off -DGGML_NATIVE=off"
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
-CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DLLAMA_BLAS=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_BLAS=off -DGGML_ACCELERATE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}_static"
echo "Building static library"
build
@@ -37,7 +37,7 @@ case "${GOARCH}" in
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
init_vars
-CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_BLAS=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=off -DGGML_BLAS=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu"
echo "Building LCD CPU"
build
@@ -49,7 +49,7 @@ case "${GOARCH}" in
# Approximately 400% faster than LCD on same CPU
#
init_vars
-CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_BLAS=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=off -DGGML_BLAS=off -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
@@ -61,7 +61,7 @@ case "${GOARCH}" in
# Approximately 10% faster than AVX on same CPU
#
init_vars
-CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_BLAS=off -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
+CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=on -DGGML_BLAS=off -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
@@ -75,14 +75,14 @@ case "${GOARCH}" in
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
-CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_BLAS=off -DCMAKE_SYSTEM_NAME=Darwin -DBUILD_SHARED_LIBS=off -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}_static"
echo "Building static library"
build
if [ -z "$OLLAMA_SKIP_METAL_GENERATE" ]; then
init_vars
-CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
+CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/metal"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
build


@@ -51,7 +51,7 @@ if [ -z "${CUDACXX}" ]; then
export CUDACXX=$(command -v nvcc)
fi
fi
-COMMON_CMAKE_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DLLAMA_OPENMP=off"
+COMMON_CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off"
source $(dirname $0)/gen_common.sh
init_vars
git_module_setup
@@ -64,7 +64,7 @@ if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ];
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
-CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DLLAMA_OPENMP=off ${CMAKE_DEFS}"
+CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DGGML_NATIVE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}_static"
echo "Building static library"
build
@@ -77,29 +77,29 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
init_vars init_vars
echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\"" echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\""
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}" CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cpu" BUILD_DIR="../build/linux/${ARCH}/cpu"
echo "Building custom CPU" echo "Building custom CPU"
build build
compress compress
else else
# Darwin Rosetta x86 emulation does NOT support AVX, AVX2, AVX512 # Darwin Rosetta x86 emulation does NOT support AVX, AVX2, AVX512
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer # -DGGML_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX) # -DGGML_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen # -DGGML_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver # -DGGML_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
# Note: the following seem to yield slower results than AVX2 - ymmv # Note: the following seem to yield slower results than AVX2 - ymmv
# -DLLAMA_AVX512 -- 2017 Intel Skylake and High End DeskTop (HEDT) # -DGGML_AVX512 -- 2017 Intel Skylake and High End DeskTop (HEDT)
# -DLLAMA_AVX512_VBMI -- 2018 Intel Cannon Lake # -DGGML_AVX512_VBMI -- 2018 Intel Cannon Lake
# -DLLAMA_AVX512_VNNI -- 2021 Intel Alder Lake # -DGGML_AVX512_VNNI -- 2021 Intel Alder Lake
COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_OPENMP=off" COMMON_CPU_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_OPENMP=off"
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then
# #
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta) # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
# #
init_vars init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}" CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cpu" BUILD_DIR="../build/linux/${ARCH}/cpu"
echo "Building LCD CPU" echo "Building LCD CPU"
build build
@@ -116,7 +116,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
# Approximately 400% faster than LCD on same CPU
#
init_vars
-CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
@@ -129,7 +129,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
# Approximately 10% faster than AVX on same CPU
#
init_vars
-CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
+CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
build
@@ -170,15 +170,15 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
#
# CUDA compute < 6.0 lacks proper FP16 support on ARM.
# Disabling has minimal performance effect while maintaining compatibility.
-ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
+ARM64_DEFS="-DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_CUDA_F16=off"
fi
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then
echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\""
-CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
+CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
echo "Building custom CUDA GPU"
else
-CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
+CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DGGML_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} -DCMAKE_LIBRARY_PATH=/usr/local/cuda/compat"
fi
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
@@ -216,7 +216,7 @@ if [ -z "${OLLAMA_SKIP_ONEAPI_GENERATE}" -a -d "${ONEAPI_ROOT}" ]; then
init_vars
source ${ONEAPI_ROOT}/setvars.sh --force # set up environment variables for oneAPI
CC=icx
-CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL=ON -DLLAMA_SYCL_F16=OFF"
+CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL=ON -DGGML_SYCL_F16=OFF"
BUILD_DIR="../build/linux/${ARCH}/oneapi"
EXTRA_LIBS="-fsycl -Wl,-rpath,${ONEAPI_ROOT}/compiler/latest/lib,-rpath,${ONEAPI_ROOT}/mkl/latest/lib,-rpath,${ONEAPI_ROOT}/tbb/latest/lib,-rpath,${ONEAPI_ROOT}/compiler/latest/opt/oclfpga/linux64/lib -lOpenCL -lmkl_core -lmkl_sycl_blas -lmkl_intel_ilp64 -lmkl_tbb_thread -ltbb"
DEBUG_FLAGS="" # icx compiles with -O0 if we pass -g, so we must remove it
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
fi
init_vars
-CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
+CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""


@@ -39,8 +39,8 @@ function init_vars {
}
$script:cmakeDefs = @(
"-DBUILD_SHARED_LIBS=on",
-"-DLLAMA_NATIVE=off",
-"-DLLAMA_OPENMP=off"
+"-DGGML_NATIVE=off",
+"-DGGML_OPENMP=off"
)
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
@@ -182,9 +182,9 @@ function cleanup {
}
-# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
-# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
-# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
+# -DGGML_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
+# -DGGML_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
+# -DGGML_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
function build_static() {
@@ -204,13 +204,13 @@ function build_static() {
"-DCMAKE_C_COMPILER=gcc.exe", "-DCMAKE_C_COMPILER=gcc.exe",
"-DCMAKE_CXX_COMPILER=g++.exe", "-DCMAKE_CXX_COMPILER=g++.exe",
"-DBUILD_SHARED_LIBS=off", "-DBUILD_SHARED_LIBS=off",
"-DLLAMA_NATIVE=off", "-DGGML_NATIVE=off",
"-DLLAMA_AVX=off", "-DGGML_AVX=off",
"-DLLAMA_AVX2=off", "-DGGML_AVX2=off",
"-DLLAMA_AVX512=off", "-DGGML_AVX512=off",
"-DLLAMA_F16C=off", "-DGGML_F16C=off",
"-DLLAMA_FMA=off", "-DGGML_FMA=off",
"-DLLAMA_OPENMP=off") "-DGGML_OPENMP=off")
$script:buildDir="../build/windows/${script:ARCH}_static" $script:buildDir="../build/windows/${script:ARCH}_static"
write-host "Building static library" write-host "Building static library"
build build
@@ -224,7 +224,7 @@ function build_cpu($gen_arch) {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) { if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
# remaining llama.cpp builds use MSVC # remaining llama.cpp builds use MSVC
init_vars init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs $script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DGGML_AVX=off", "-DGGML_AVX2=off", "-DGGML_AVX512=off", "-DGGML_FMA=off", "-DGGML_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu" $script:buildDir="../build/windows/${script:ARCH}/cpu"
$script:distDir="$script:DIST_BASE\cpu" $script:distDir="$script:DIST_BASE\cpu"
write-host "Building LCD CPU" write-host "Building LCD CPU"
@@ -239,7 +239,7 @@ function build_cpu($gen_arch) {
function build_cpu_avx() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
init_vars
-$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
+$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DGGML_AVX=on", "-DGGML_AVX2=off", "-DGGML_AVX512=off", "-DGGML_FMA=off", "-DGGML_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
$script:distDir="$script:DIST_BASE\cpu_avx"
write-host "Building AVX CPU"
@@ -254,7 +254,7 @@ function build_cpu_avx() {
function build_cpu_avx2() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
init_vars
-$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
+$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DGGML_AVX=on", "-DGGML_AVX2=on", "-DGGML_AVX512=off", "-DGGML_FMA=on", "-DGGML_F16C=on") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
$script:distDir="$script:DIST_BASE\cpu_avx2"
write-host "Building AVX2 CPU"
@@ -279,9 +279,9 @@ function build_cuda() {
$script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT" $script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @( $script:cmakeDefs += @(
"-A", "x64", "-A", "x64",
"-DLLAMA_CUDA=ON", "-DGGML_CUDA=ON",
"-DLLAMA_AVX=on", "-DGGML_AVX=on",
"-DLLAMA_AVX2=off", "-DGGML_AVX2=off",
"-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR",
"-DCMAKE_CUDA_FLAGS=-t8", "-DCMAKE_CUDA_FLAGS=-t8",
"-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}" "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}"
@@ -319,7 +319,7 @@ function build_oneapi() {
$script:distDir ="$script:DIST_BASE\oneapi$script:ONEAPI_VARIANT"
$script:cmakeDefs += @(
"-G", "MinGW Makefiles",
-"-DLLAMA_SYCL=ON",
+"-DGGML_SYCL=ON",
"-DCMAKE_C_COMPILER=icx",
"-DCMAKE_CXX_COMPILER=icx",
"-DCMAKE_BUILD_TYPE=Release"
@@ -365,10 +365,11 @@ function build_rocm() {
"-G", "Ninja", "-G", "Ninja",
"-DCMAKE_C_COMPILER=clang.exe", "-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe", "-DCMAKE_CXX_COMPILER=clang++.exe",
"-DLLAMA_HIPBLAS=on", "-DGGML_HIPBLAS=on",
"-DLLAMA_CUDA_NO_PEER_COPY=on",
"-DHIP_PLATFORM=amd", "-DHIP_PLATFORM=amd",
"-DLLAMA_AVX=on", "-DGGML_AVX=on",
"-DLLAMA_AVX2=off", "-DGGML_AVX2=off",
"-DCMAKE_POSITION_INDEPENDENT_CODE=on", "-DCMAKE_POSITION_INDEPENDENT_CODE=on",
"-DAMDGPU_TARGETS=$(amdGPUs)", "-DAMDGPU_TARGETS=$(amdGPUs)",
"-DGPU_TARGETS=$(amdGPUs)" "-DGPU_TARGETS=$(amdGPUs)"


@@ -53,7 +53,7 @@ func (llm *ggla) Tensors() Tensors {
return llm.tensors
}
-func (llm *ggla) decode(rs io.ReadSeeker) error {
+func (llm *ggla) decode(rs io.ReadSeeker) (retErr error) {
var r uint32
if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
return err
@@ -69,9 +69,18 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
for {
var dims uint32
if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return err
}
defer func() {
if errors.Is(retErr, io.EOF) {
retErr = io.ErrUnexpectedEOF
}
}()
var namesize uint32
if err := binary.Read(rs, binary.LittleEndian, &namesize); err != nil {
return err
@@ -108,7 +117,7 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
return err
}
-if _, err := rs.Seek((offset+31)&-32, io.SeekStart); err != nil {
+if _, err := rs.Seek((offset+31)&-32-offset, io.SeekCurrent); err != nil {
return err
}
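The expression `(offset+31)&-32` rounds `offset` up to the next 32-byte boundary; the new form seeks by the relative padding from the current position rather than to an absolute target. A small sketch of the arithmetic:

```go
package main

import "fmt"

func main() {
	for _, offset := range []int64{0, 1, 31, 32, 33, 100} {
		aligned := (offset + 31) & -32 // round up to a multiple of 32
		delta := aligned - offset      // relative distance for io.SeekCurrent
		fmt.Printf("offset=%3d aligned=%3d delta=%2d\n", offset, aligned, delta)
	}
}
```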


@@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"io" "io"
"strings" "strings"
"github.com/ollama/ollama/util/bufioutil"
) )
type GGML struct {
@@ -278,7 +280,18 @@ func DetectGGMLType(b []byte) string {
}
}
-func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
+// DecodeGGML decodes a GGML model from the given reader.
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// the maxArraySize is negative, all arrays are collected.
func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
if maxArraySize == 0 {
maxArraySize = 1024
}
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
return nil, 0, err
@@ -291,17 +304,15 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
case FILE_MAGIC_GGLA:
c = &containerGGLA{}
case FILE_MAGIC_GGUF_LE:
-c = &containerGGUF{ByteOrder: binary.LittleEndian}
+c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
case FILE_MAGIC_GGUF_BE:
-c = &containerGGUF{ByteOrder: binary.BigEndian}
+c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
default:
return nil, 0, errors.New("invalid file magic")
}
model, err := c.Decode(rs)
-if errors.Is(err, io.EOF) {
-// noop
-} else if err != nil {
+if err != nil {
return nil, 0, err
}
@@ -321,7 +332,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
embedding := llm.KV().EmbeddingLength()
heads := llm.KV().HeadCount()
headsKV := llm.KV().HeadCountKV()
-vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
+vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
embeddingHeads := llm.KV().EmbeddingHeadCount()
embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
@@ -355,9 +366,18 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16), 4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
) )
} }
case "gemma": case "gemma", "gemma2":
fullOffload = 4 * batch * (embedding + vocab) fullOffload = max(
partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128 4*batch*(embedding+vocab),
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
)
partialOffload = max(
4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
4*embeddingHeadsK*context*8+
embedding*embeddingHeadsK*heads*9/16,
)
case "command-r": case "command-r":
fullOffload = max( fullOffload = max(
4*batch*(embedding+vocab), 4*batch*(embedding+vocab),
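
A note on the vocab change in the GraphSize hunk above: instead of type-asserting the token list to []any and counting it, the code reads the recorded element count from the *array type introduced in llm/gguf.go below, so memory estimates stay correct even when the array values themselves were skipped during decoding. A rough sketch of the idea (field names as in this diff):

    // array mirrors the type added in llm/gguf.go: it always records how
    // many elements a GGUF metadata array had, even if values were skipped.
    type array struct {
        size   int
        values []any
    }

    // vocabSize works whether or not the token strings were collected.
    func vocabSize(kv map[string]any) uint64 {
        if a, ok := kv["tokenizer.ggml.tokens"].(*array); ok {
            return uint64(a.size)
        }
        return 0 // hypothetical fallback; the real code assumes the key exists
    }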

llm/ggml_test.go (new file, 1 line)

@@ -0,0 +1 @@
+package llm

llm/gguf.go

@@ -3,11 +3,10 @@ package llm
 import (
 	"bytes"
 	"encoding/binary"
+	"encoding/json"
 	"fmt"
 	"io"
-	"log/slog"
 	"strings"
 )
 
 type containerGGUF struct {
@@ -29,6 +28,12 @@ type containerGGUF struct {
 		NumTensor uint64
 		NumKV     uint64
 	}
+
+	maxArraySize int
+}
+
+func (c *containerGGUF) canCollectArray(size int) bool {
+	return c.maxArraySize < 0 || size <= c.maxArraySize
 }
 
 func (c *containerGGUF) Name() string {
@@ -54,7 +59,6 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
 	}
 
 	model := newGGUF(c)
-	slog.Debug(fmt.Sprintf("model = %#v", model))
 	if err := model.Decode(rs); err != nil {
 		return nil, err
 	}
@@ -85,6 +89,8 @@ type gguf struct {
 	tensors []*Tensor
 
 	parameters uint64
+
+	scratch [16 << 10]byte
 }
 
 func newGGUF(container *containerGGUF) *gguf {
@@ -181,34 +187,34 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 	}
 
 	// decode tensors
-	for i := 0; uint64(i) < llm.numTensor(); i++ {
+	for range llm.numTensor() {
 		name, err := readGGUFString(llm, rs)
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor name: %w", err)
 		}
 
 		// dims is the number of dimensions in the tensor
 		dims, err := readGGUF[uint32](llm, rs)
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor dimensions: %w", err)
 		}
 
 		shape := [4]uint64{1, 1, 1, 1}
 		for i := 0; uint32(i) < dims; i++ {
 			shape[i], err = readGGUF[uint64](llm, rs)
 			if err != nil {
-				return err
+				return fmt.Errorf("failed to read tensor shape: %w", err)
 			}
 		}
 
 		kind, err := readGGUF[uint32](llm, rs)
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor kind: %w", err)
 		}
 
 		offset, err := readGGUF[uint64](llm, rs)
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor offset: %w", err)
 		}
 
 		tensor := Tensor{
@@ -230,24 +236,19 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 		alignment = 32
 	}
 
-	offset, err := rs.Seek(0, io.SeekCurrent)
-	if err != nil {
-		return err
-	}
-
-	padding := llm.padding(offset, int64(alignment))
-	if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
-		return err
-	}
-
 	for _, tensor := range llm.tensors {
-		if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
-			return err
+		offset, err := rs.Seek(0, io.SeekCurrent)
+		if err != nil {
+			return fmt.Errorf("failed to get current offset: %w", err)
 		}
 
-		padding := llm.padding(int64(tensor.Size()), int64(alignment))
+		padding := llm.padding(offset, int64(alignment))
 		if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
-			return err
+			return fmt.Errorf("failed to seek to init padding: %w", err)
+		}
+
+		if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
+			return fmt.Errorf("failed to seek to tensor: %w", err)
 		}
 	}
@@ -285,22 +286,48 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
 	return b.String(), nil
 }
 
+func discardGGUFString(llm *gguf, r io.Reader) error {
+	buf := llm.scratch[:8]
+	_, err := io.ReadFull(r, buf)
+	if err != nil {
+		return err
+	}
+
+	size := int(llm.ByteOrder.Uint64(buf))
+	for size > 0 {
+		n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
+		if err != nil {
+			return err
+		}
+
+		size -= n
+	}
+	return nil
+}
+
 func readGGUFString(llm *gguf, r io.Reader) (string, error) {
 	if llm.Version == 1 {
 		return readGGUFV1String(llm, r)
 	}
 
-	var length uint64
-	if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
+	buf := llm.scratch[:8]
+	_, err := io.ReadFull(r, buf)
+	if err != nil {
 		return "", err
 	}
 
-	var b bytes.Buffer
-	if _, err := io.CopyN(&b, r, int64(length)); err != nil {
+	length := int(llm.ByteOrder.Uint64(buf))
+	if length > len(llm.scratch) {
+		buf = make([]byte, length)
+	} else {
+		buf = llm.scratch[:length]
+	}
+	clear(buf)
+
+	_, err = io.ReadFull(r, buf)
+	if err != nil {
 		return "", err
 	}
 
-	return b.String(), nil
+	return string(buf), nil
 }
 
 func writeGGUFString(llm *gguf, w io.Writer, s string) error {
@@ -316,7 +343,16 @@ func writeGGUFString(llm *gguf, w io.Writer, s string) error {
 	return err
 }
 
-func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
+type array struct {
+	size   int
+	values []any
+}
+
+func (a *array) MarshalJSON() ([]byte, error) {
+	return json.Marshal(a.values)
+}
+
+func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
 	t, err := readGGUF[uint32](llm, r)
 	if err != nil {
 		return nil, err
@@ -327,7 +363,12 @@ func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
 		return nil, err
 	}
 
-	for i := 0; uint32(i) < n; i++ {
+	a := &array{size: int(n)}
+	if llm.canCollectArray(int(n)) {
+		a.values = make([]any, 0, int(n))
+	}
+
+	for i := range n {
 		var e any
 		switch t {
 		case ggufTypeUint8:
@@ -361,13 +402,15 @@ func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
 			return nil, err
 		}
 
-		a = append(a, e)
+		if a.values != nil {
+			a.values[i] = e
+		}
 	}
 
-	return
+	return a, nil
 }
 
-func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
+func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
 	if llm.Version == 1 {
 		return readGGUFV1Array(llm, r)
 	}
@@ -382,7 +425,12 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
 		return nil, err
 	}
 
-	for i := 0; uint64(i) < n; i++ {
+	a := &array{size: int(n)}
+	if llm.canCollectArray(int(n)) {
+		a.values = make([]any, int(n))
+	}
+
+	for i := range n {
 		var e any
 		switch t {
 		case ggufTypeUint8:
@@ -408,7 +456,11 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
 		case ggufTypeBool:
 			e, err = readGGUF[bool](llm, r)
 		case ggufTypeString:
-			e, err = readGGUFString(llm, r)
+			if a.values != nil {
+				e, err = readGGUFString(llm, r)
+			} else {
+				err = discardGGUFString(llm, r)
+			}
 		default:
 			return nil, fmt.Errorf("invalid array type: %d", t)
 		}
@@ -416,10 +468,12 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
 			return nil, err
 		}
 
-		a = append(a, e)
+		if a.values != nil {
+			a.values[i] = e
+		}
 	}
 
-	return
+	return a, nil
 }
 
 func writeGGUFArray[S ~[]E, E any](llm *gguf, w io.Writer, t uint32, s S) error {
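
discardGGUFString above drains a length-prefixed string through the fixed 16 KiB scratch buffer, so skipping a very large string array (for example a 100k-entry vocabulary) never allocates per element. An equivalent way to express the drain step with the standard library, offered as a sketch rather than the patch's code, is io.CopyN into io.Discard:

    package main

    import (
        "bytes"
        "encoding/binary"
        "fmt"
        "io"
    )

    // skipString drains one length-prefixed GGUF string without keeping it.
    func skipString(r io.Reader, order binary.ByteOrder) error {
        var prefix [8]byte
        if _, err := io.ReadFull(r, prefix[:]); err != nil {
            return err
        }
        n := int64(order.Uint64(prefix[:]))
        _, err := io.CopyN(io.Discard, r, n) // throw the payload away
        return err
    }

    func main() {
        var b bytes.Buffer
        binary.Write(&b, binary.LittleEndian, uint64(5))
        b.WriteString("hello")
        fmt.Println(skipString(&b, binary.LittleEndian)) // <nil>
    }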

llm/llama.go

@@ -1,12 +1,13 @@
 package llm
 
-// #cgo CFLAGS: -Illama.cpp
-// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
-// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
-// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
-// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
-// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
-// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
+// #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include
+// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
+// #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
+// #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
+// #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
+// #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
+// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
+// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
 // #include <stdlib.h>
 // #include "llama.h"
 import "C"

llm/memory_test.go

@@ -22,13 +22,14 @@ func TestEstimateGPULayers(t *testing.T) {
 	defer f.Close()
 	gguf := NewGGUFV3(binary.LittleEndian)
 	inputLayerCount := 5
 
 	tensors := []Tensor{
-		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
 	}
 	assert.Len(t, tensors, inputLayerCount+1)
 
 	err = gguf.Encode(f, KV{
@@ -45,8 +46,10 @@ func TestEstimateGPULayers(t *testing.T) {
 	}, tensors)
 	require.NoError(t, err)
 
-	ggml, err := LoadModel(f.Name())
-	require.NoError(t, err)
+	ggml, err := LoadModel(f.Name(), 0)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 	// Simple CPU scenario
 	gpus := []gpu.GpuInfo{


@@ -1,8 +1,8 @@
 diff --git a/common/common.cpp b/common/common.cpp
-index 73ff0e85..6adb1a92 100644
+index 2c05a4d4..927f0e3d 100644
 --- a/common/common.cpp
 +++ b/common/common.cpp
-@@ -2447,6 +2447,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
+@@ -2093,6 +2093,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
      mparams.use_mmap        = params.use_mmap;
      mparams.use_mlock       = params.use_mlock;
      mparams.check_tensors   = params.check_tensors;
@@ -12,10 +12,10 @@ index 73ff0e85..6adb1a92 100644
          mparams.kv_overrides = NULL;
      } else {
 diff --git a/common/common.h b/common/common.h
-index 58ed72f4..0bb2605e 100644
+index 65c0ef81..ebca2c77 100644
 --- a/common/common.h
 +++ b/common/common.h
-@@ -180,6 +180,13 @@ struct gpt_params {
+@@ -184,6 +184,13 @@ struct gpt_params {
      std::string mmproj = "";        // path to multimodal projector
      std::vector<std::string> image; // path to image file(s)
@@ -26,6 +26,6 @@ index 58ed72f4..0bb2605e 100644
 +    // context pointer passed to the progress callback
 +    void * progress_callback_user_data;
 +
-     // server params
-     int32_t port           = 8080; // server listens on this network port
-     int32_t timeout_read   = 600;  // http read timeout in seconds
+     // embedding
+     bool embedding         = false; // get only sentence embedding
+     int32_t embd_normalize = 2;     // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)


@@ -1,17 +1,8 @@
-From 544a2d2e646d39e878d87dfbb3398a356bc560ab Mon Sep 17 00:00:00 2001
-From: Michael Yang <mxyng@pm.me>
-Date: Thu, 23 May 2024 11:18:45 -0700
-Subject: [PATCH] throw exception on load errors
-
----
- llama.cpp | 25 ++++++++++++++++---------
- 1 file changed, 16 insertions(+), 9 deletions(-)
-
-diff --git a/llama.cpp b/llama.cpp
-index 15c66077..8ba90b6a 100644
---- a/llama.cpp
-+++ b/llama.cpp
-@@ -6346,7 +6346,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 73f52435..58a00fb1 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -7241,7 +7241,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
          }
      } catch (const std::exception & err) {
          LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
@@ -20,7 +11,7 @@ index 15c66077..8ba90b6a 100644
      }
 
      return 0;
-@@ -15600,16 +15600,23 @@ struct llama_model * llama_load_model_from_file(
+@@ -17564,16 +17564,23 @@ struct llama_model * llama_load_model_from_file(
      }
      model->rpc_servers.push_back(servers);
  }
@@ -52,6 +43,3 @@ index 15c66077..8ba90b6a 100644
      }
 
      return model;
---
-2.45.1


@@ -1,7 +1,7 @@
-diff --git a/ggml-metal.m b/ggml-metal.m
+diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
 index 0207b787..b5e9884b 100644
---- a/ggml-metal.m
-+++ b/ggml-metal.m
+--- a/ggml/src/ggml-metal.m
++++ b/ggml/src/ggml-metal.m
 @@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
                  // to the matrix-vector kernel
                  int ne11_mm_min = 1;


@@ -1,11 +1,11 @@
-diff --git a/llama.cpp b/llama.cpp
-index 61948751..4b72a293 100644
---- a/llama.cpp
-+++ b/llama.cpp
-@@ -4824,16 +4824,7 @@ static void llm_load_vocab(
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 2b9ace28..172640e2 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -5357,16 +5357,7 @@ static void llm_load_vocab(
 
-     // for now, only BPE models have pre-tokenizers
      if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
+         vocab.tokenizer_add_space_prefix = false;
+         vocab.tokenizer_clean_spaces = true;
 -        if (tokenizer_pre.empty()) {
 -            LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
 -            LLAMA_LOG_WARN("%s: \n", __func__);
@@ -20,13 +20,13 @@ index 61948751..4b72a293 100644
              vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
          } else if (
                  tokenizer_pre == "llama3" ||
-@@ -4888,7 +4879,8 @@ static void llm_load_vocab(
-                 tokenizer_pre == "poro-chat") {
-             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
+@@ -5439,7 +5430,8 @@ static void llm_load_vocab(
+                 tokenizer_pre == "jais") {
+             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
          } else {
 -            throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
 +            LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
 +            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
          }
-    } else {
+    } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
         vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;


@@ -1,7 +1,7 @@
-diff --git a/llama.cpp b/llama.cpp
+diff --git a/src/llama.cpp b/src/llama.cpp
 index 40d2ec2c..f34eb79a 100644
---- a/llama.cpp
-+++ b/llama.cpp
+--- a/src/llama.cpp
++++ b/src/llama.cpp
 @@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
          struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
          cb(kq, "kq", il);


@@ -0,0 +1,45 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 1fe2b9f7..a43312a7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13689,7 +13689,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead
- const bool has_logits = !cparams.embeddings;
+ const bool has_logits = cparams.causal_attn;
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
@@ -13959,17 +13959,25 @@ static int llama_decode_internal(
// no output
res = nullptr;
embd = nullptr;
- } else if (cparams.embeddings) {
- res = nullptr; // do not extract logits for embedding case
- embd = gf->nodes[gf->n_nodes - 1];
- if (strcmp(embd->name, "result_embd_pooled") != 0) {
- embd = gf->nodes[gf->n_nodes - 2];
+ }
+
+ if (cparams.embeddings) {
+ for (int i = gf->n_nodes - 1; i >= 0; --i) {
+ embd = gf->nodes[i];
+ if (strcmp(embd->name, "result_embd_pooled") == 0) {
+ break;
+ }
}
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
- } else {
+ } else {
embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
+
+ if (!cparams.causal_attn) {
+ res = nullptr; // do not extract logits when not needed
+ }
+
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
ggml_backend_sched_alloc_graph(lctx.sched, gf);


@@ -0,0 +1,42 @@
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 95fbe3d0..5a02a6ec 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -32,6 +33,14 @@
#include <cinttypes>
#include <limits>
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+ #define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
//#define CLIP_DEBUG_FUNCTIONS
// RGB uint8 image
@@ -1055,7 +1064,22 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
return nullptr;
}
+#ifdef _WIN32
+ int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0);
+ if (!wlen) {
+ return NULL;
+ }
+ wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
+ wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wbuf, wlen);
+ if (!wlen) {
+ free(wbuf);
+ return NULL;
+ }
+ auto fin = std::ifstream(wbuf, std::ios::binary);
+ free(wbuf);
+#else
auto fin = std::ifstream(fname, std::ios::binary);
+#endif
if (!fin) {
LOG_TEE("cannot open model file for loading tensors\n");
clip_free(new_clip);


@@ -0,0 +1,60 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 721b8f4e..cfe7ac40 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8420,14 +8420,14 @@ struct llm_build_context {
}
struct ggml_tensor * build_inp_mean() {
- lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+ lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, cparams.n_seq_max);
cb(lctx.inp_mean, "inp_mean", -1);
ggml_set_input(lctx.inp_mean);
return lctx.inp_mean;
}
struct ggml_tensor * build_inp_cls() {
- lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_seq_max);
cb(lctx.inp_cls, "inp_cls", -1);
ggml_set_input(lctx.inp_cls);
return lctx.inp_cls;
@@ -13847,19 +13847,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
float * data = (float *) lctx.inp_mean->data;
- memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
+ memset(lctx.inp_mean->data, 0, n_tokens * cparams.n_seq_max * ggml_element_size(lctx.inp_mean));
std::vector<uint64_t> sum(n_tokens, 0);
for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
-
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
sum[seq_id] += 1;
}
- std::vector<float> div(n_tokens, 0.0f);
- for (int i = 0; i < n_tokens; ++i) {
+ std::vector<float> div(cparams.n_seq_max, 0.0f);
+ for (uint32_t i = 0; i < cparams.n_seq_max; ++i) {
const uint64_t s = sum[i];
if (s > 0) {
div[i] = 1.0f/float(s);
@@ -13879,14 +13876,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
- memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+ memset(lctx.inp_cls->data, 0, cparams.n_seq_max * ggml_element_size(lctx.inp_cls));
for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
const llama_pos pos = batch.pos[i];
-
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
-
if (pos == 0) {
data[seq_id] = i;
}

llm/payload.go

@@ -38,7 +38,7 @@ func Init() error {
 	}
 
 	var variants []string
-	for v := range availableServers() {
+	for v := range getAvailableServers() {
 		variants = append(variants, v)
 	}
 	slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
@@ -50,7 +50,7 @@ func Init() error {
 // binary names may contain an optional variant separated by '_'
 // For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
 // Any library without a variant is the lowest common denominator
-func availableServers() map[string]string {
+func getAvailableServers() map[string]string {
 	payloadsDir, err := gpu.PayloadsDir()
 	if err != nil {
 		slog.Error("payload lookup error", "error", err)
@@ -80,7 +80,7 @@ func getAvailableServers() map[string]string {
 // TODO - switch to metadata based mapping
 func serversForGpu(info gpu.GpuInfo) []string {
 	// glob workDir for files that start with ollama_
-	availableServers := availableServers()
+	availableServers := getAvailableServers()
 	requested := info.Library
 	if info.Variant != gpu.CPUCapabilityNone {
 		requested += "_" + info.Variant.String()
@@ -115,27 +115,29 @@ func serversForGpu(info gpu.GpuInfo) []string {
 		servers = append(servers, alt...)
 	}
 
-	// Load up the best CPU variant if not primary requested
-	if info.Library != "cpu" {
-		variant := gpu.GetCPUCapability()
-		// If no variant, then we fall back to default
-		// If we have a variant, try that if we find an exact match
-		// Attempting to run the wrong CPU instructions will panic the
-		// process
-		if variant != gpu.CPUCapabilityNone {
-			for cmp := range availableServers {
-				if cmp == "cpu_"+variant.String() {
-					servers = append(servers, cmp)
-					break
+	if !(runtime.GOOS == "darwin" && runtime.GOARCH == "arm64") {
+		// Load up the best CPU variant if not primary requested
+		if info.Library != "cpu" {
+			variant := gpu.GetCPUCapability()
+			// If no variant, then we fall back to default
+			// If we have a variant, try that if we find an exact match
+			// Attempting to run the wrong CPU instructions will panic the
+			// process
+			if variant != gpu.CPUCapabilityNone {
+				for cmp := range availableServers {
+					if cmp == "cpu_"+variant.String() {
+						servers = append(servers, cmp)
+						break
+					}
 				}
+			} else {
+				servers = append(servers, "cpu")
 			}
-		} else {
-			servers = append(servers, "cpu")
 		}
-	}
 
-	if len(servers) == 0 {
-		servers = []string{"cpu"}
+		if len(servers) == 0 {
+			servers = []string{"cpu"}
+		}
 	}
 
 	return servers
@@ -147,7 +149,7 @@ func serverForCpu() string {
 		return "metal"
 	}
 	variant := gpu.GetCPUCapability()
-	availableServers := availableServers()
+	availableServers := getAvailableServers()
 	if variant != gpu.CPUCapabilityNone {
 		for cmp := range availableServers {
 			if cmp == "cpu_"+variant.String() {

llm/server.go

@@ -60,7 +60,12 @@ type llmServer struct {
 	sem *semaphore.Weighted
 }
 
-func LoadModel(model string) (*GGML, error) {
+// LoadModel will load a model from disk. The model must be in the GGML format.
+//
+// It collects array values for arrays with a size less than or equal to
+// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
+// the maxArraySize is negative, all arrays are collected.
+func LoadModel(model string, maxArraySize int) (*GGML, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
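
Callers now choose how much metadata to materialize. A hedged sketch of the calling convention, written as in-package helpers (the helper names are invented for illustration):

    // sketch: how callers inside this package might pick maxArraySize
    func loadForEstimate(path string) (*GGML, error) {
        // 0 → default cap of 1024 collected elements per array; memory
        // estimation only needs array sizes, not the values themselves
        return LoadModel(path, 0)
    }

    func loadWithFullMetadata(path string) (*GGML, error) {
        // negative → collect every array value, e.g. the full tokenizer vocab
        return LoadModel(path, -1)
    }
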
@@ -71,17 +76,27 @@ func LoadModel(model string) (*GGML, error) {
 	}
 	defer f.Close()
 
-	ggml, _, err := DecodeGGML(f)
+	ggml, _, err := DecodeGGML(f, maxArraySize)
 	return ggml, err
 }
 
 // NewLlamaServer will run a server for the given GPUs
 // The gpu list must be a single family.
-func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
+func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
 	var err error
 	var cpuRunner string
 	var estimate MemoryEstimate
-	var systemMemory uint64
+	var systemTotalMemory uint64
+	var systemFreeMemory uint64
+
+	systemMemInfo, err := gpu.GetCPUMem()
+	if err != nil {
+		slog.Error("failed to lookup system memory", "error", err)
+	} else {
+		systemTotalMemory = systemMemInfo.TotalMemory
+		systemFreeMemory = systemMemInfo.FreeMemory
+		slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory)
+	}
 
 	// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
 	if opts.NumGPU == 0 {
@@ -91,19 +106,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		cpuRunner = serverForCpu()
 		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 	} else {
-		if gpus[0].Library == "metal" {
-			memInfo, err := gpu.GetCPUMem()
-			if err != nil {
-				slog.Error("failed to lookup system memory", "error", err)
-			} else {
-				systemMemory = memInfo.TotalMemory
-				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
-			}
-		}
 		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 
 		switch {
-		case gpus[0].Library == "metal" && estimate.VRAMSize > systemMemory:
+		case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
 			// disable partial offloading when model is greater than total system memory as this
 			// can lead to locking up the system
 			opts.NumGPU = 0
@@ -125,7 +131,20 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 	}
 
-	availableServers := availableServers()
+	availableServers := getAvailableServers()
+	if len(availableServers) == 0 {
+		if runtime.GOOS != "windows" {
+			slog.Warn("llama server binary disappeared, reinitializing payloads")
+			err = Init()
+			if err != nil {
+				slog.Warn("failed to reinitialize payloads", "error", err)
+				return nil, err
+			}
+			availableServers = getAvailableServers()
+		} else {
+			return nil, finalErr
+		}
+	}
 	var servers []string
 	if cpuRunner != "" {
 		servers = []string{cpuRunner}
@@ -202,7 +221,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		if g.Library == "metal" &&
 			uint64(opts.NumGPU) > 0 &&
 			uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
-			opts.UseMMap = api.TriStateFalse
+			opts.UseMMap = new(bool)
+			*opts.UseMMap = false
 		}
 	}
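
Replacing the custom api.TriState with *bool keeps three states with plain JSON semantics: a nil pointer means the client never set the option, while true and false are explicit. A minimal, self-contained sketch of the pattern:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    type Options struct {
        UseMMap *bool `json:"use_mmap,omitempty"`
    }

    func main() {
        var unset, off Options
        _ = json.Unmarshal([]byte(`{}`), &unset)
        _ = json.Unmarshal([]byte(`{"use_mmap":false}`), &off)

        fmt.Println(unset.UseMMap == nil)                // true: left to the server
        fmt.Println(off.UseMMap != nil && !*off.UseMMap) // true: explicit opt-out
    }
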
@@ -211,7 +231,12 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	}
 
 	// Windows CUDA should not use mmap for best performance
-	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda") || opts.UseMMap == api.TriStateFalse {
+	// Linux with a model larger than free space, mmap leads to thrashing
+	// For CPU loads we want the memory to be allocated, not FS cache
+	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) ||
+		(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) ||
+		(gpus[0].Library == "cpu" && opts.UseMMap == nil) ||
+		(opts.UseMMap != nil && !*opts.UseMMap) {
 		params = append(params, "--no-mmap")
 	}
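
The rewritten condition only disables mmap on the user's behalf when the option is unset; an explicit false always wins and an explicit true is never overridden. The same logic as a standalone predicate (a sketch with simplified parameters, not the code above):

    // noMMap mirrors the four branches of the condition above.
    func noMMap(goos, library string, freeMemory, totalSize uint64, useMMap *bool) bool {
        unset := useMMap == nil
        switch {
        case goos == "windows" && library == "cuda" && unset:
            return true // Windows CUDA: mmap hurts performance
        case goos == "linux" && freeMemory < totalSize && unset:
            return true // model larger than free RAM: mmap would thrash
        case library == "cpu" && unset:
            return true // CPU loads: allocate memory rather than rely on FS cache
        case useMMap != nil && !*useMMap:
            return true // explicit user opt-out
        }
        return false
    }
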
@@ -223,15 +248,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--numa")
 	}
 
-	numParallel := envconfig.NumParallel
-
-	// TODO (jmorganca): multimodal models don't support parallel yet
-	// see https://github.com/ollama/ollama/issues/4165
-	if len(projectors) > 0 {
-		numParallel = 1
-		slog.Warn("multimodal models don't support parallel requests yet")
-	}
-
 	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
 
 	if estimate.TensorSplit != "" {
@@ -408,7 +424,7 @@ func projectorMemoryRequirements(filename string) uint64 {
 	}
 	defer file.Close()
 
-	ggml, _, err := DecodeGGML(file)
+	ggml, _, err := DecodeGGML(file, 0)
 	if err != nil {
 		return 0
 	}
@@ -558,6 +574,9 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			if s.status != nil && s.status.LastErrMsg != "" {
 				msg = s.status.LastErrMsg
 			}
+			if strings.Contains(msg, "unknown model") {
+				return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
+			}
 			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
 		default:
 		}
@@ -660,7 +679,7 @@ type CompletionRequest struct {
 	Prompt  string
 	Format  string
 	Images  []ImageData
-	Options api.Options
+	Options *api.Options
 }
 
 type CompletionResponse struct {
@@ -680,10 +699,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	}
 	defer s.sem.Release(1)
 
-	// only allow maximum 10 "context shifts" to avoid infinite generation
+	// put an upper limit on num_predict to avoid the model running on forever
 	if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
 		req.Options.NumPredict = 10 * s.options.NumCtx
-		slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
 	}
 
 	request := map[string]any{
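
The cap bounds generation at ten full context windows: with num_ctx of 2048, a request for unlimited tokens (num_predict = -1) is clamped to 20480. The check as a standalone helper:

    // clampPredict mirrors the cap above: at most ten context windows.
    func clampPredict(numPredict, numCtx int) int {
        if numPredict < 0 || numPredict > 10*numCtx {
            return 10 * numCtx // e.g. 20480 when numCtx is 2048
        }
        return numPredict
    }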

llm/status.go

@@ -25,6 +25,7 @@ var errorPrefixes = []string{
 	"CUDA error",
 	"cudaMalloc failed",
 	"\"ERR\"",
+	"error loading model",
 }
 
 func (w *StatusWriter) Write(b []byte) (int, error) {
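
Adding "error loading model" to this list lets the runner's output scanner retain llama.cpp load failures as LastErrMsg, which the new "unknown model" check in llm/server.go above then matches. A plausible shape for the matching step, offered as a sketch only since the Write body is not shown in this diff:

    // Sketch only: return the first line carrying a known error prefix.
    func lastError(lines []string, prefixes []string) string {
        for _, line := range lines {
            for _, prefix := range prefixes {
                if strings.Contains(line, prefix) {
                    return line
                }
            }
        }
        return ""
    }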

openai/openai.go

@@ -12,6 +12,7 @@ import (
 	"github.com/gin-gonic/gin"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
 )
 
 type Error struct {
@@ -42,6 +43,12 @@ type ChunkChoice struct {
 	FinishReason *string `json:"finish_reason"`
 }
 
+type CompleteChunkChoice struct {
+	Text         string  `json:"text"`
+	Index        int     `json:"index"`
+	FinishReason *string `json:"finish_reason"`
+}
+
 type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
@@ -85,6 +92,51 @@ type ChatCompletionChunk struct {
 	Choices []ChunkChoice `json:"choices"`
 }
 
+// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
+type CompletionRequest struct {
+	Model            string   `json:"model"`
+	Prompt           string   `json:"prompt"`
+	FrequencyPenalty float32  `json:"frequency_penalty"`
+	MaxTokens        *int     `json:"max_tokens"`
+	PresencePenalty  float32  `json:"presence_penalty"`
+	Seed             *int     `json:"seed"`
+	Stop             any      `json:"stop"`
+	Stream           bool     `json:"stream"`
+	Temperature      *float32 `json:"temperature"`
+	TopP             float32  `json:"top_p"`
+}
+
+type Completion struct {
+	Id                string                `json:"id"`
+	Object            string                `json:"object"`
+	Created           int64                 `json:"created"`
+	Model             string                `json:"model"`
+	SystemFingerprint string                `json:"system_fingerprint"`
+	Choices           []CompleteChunkChoice `json:"choices"`
+	Usage             Usage                 `json:"usage,omitempty"`
+}
+
+type CompletionChunk struct {
+	Id                string                `json:"id"`
+	Object            string                `json:"object"`
+	Created           int64                 `json:"created"`
+	Choices           []CompleteChunkChoice `json:"choices"`
+	Model             string                `json:"model"`
+	SystemFingerprint string                `json:"system_fingerprint"`
+}
+
+type Model struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
+type ListCompletion struct {
+	Object string  `json:"object"`
+	Data   []Model `json:"data"`
+}
+
 func NewError(code int, message string) ErrorResponse {
 	var etype string
 	switch code {
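
These types give the middleware a place to land the OpenAI legacy /v1/completions payload before it is translated to an api.GenerateRequest. A hypothetical client-side payload (the model name and values are invented for illustration):

    func ptr[T any](v T) *T { return &v }

    func exampleRequest() CompletionRequest {
        return CompletionRequest{
            Model:       "example-model", // hypothetical name
            Prompt:      "Why is the sky blue?",
            MaxTokens:   ptr(64),
            Temperature: ptr(float32(0.8)), // doubled to 1.6 by fromCompleteRequest below
            Stop:        []string{"\n"},
        }
    }
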
@@ -145,7 +197,79 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
-func fromRequest(r ChatCompletionRequest) api.ChatRequest {
+func toCompletion(id string, r api.GenerateResponse) Completion {
+	return Completion{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           r.CreatedAt.Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+		Usage: Usage{
+			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
+			PromptTokens:     r.PromptEvalCount,
+			CompletionTokens: r.EvalCount,
+			TotalTokens:      r.PromptEvalCount + r.EvalCount,
+		},
+	}
+}
+
+func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
+	return CompletionChunk{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           time.Now().Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+	}
+}
+
+func toListCompletion(r api.ListResponse) ListCompletion {
+	var data []Model
+	for _, m := range r.Models {
+		data = append(data, Model{
+			Id:      m.Name,
+			Object:  "model",
+			Created: m.ModifiedAt.Unix(),
+			OwnedBy: model.ParseName(m.Name).Namespace,
+		})
+	}
+
+	return ListCompletion{
+		Object: "list",
+		Data:   data,
+	}
+}
+
+func toModel(r api.ShowResponse, m string) Model {
+	return Model{
+		Id:      m,
+		Object:  "model",
+		Created: r.ModifiedAt.Unix(),
+		OwnedBy: model.ParseName(m).Namespace,
+	}
+}
+
+func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	var messages []api.Message
 	for _, msg := range r.Messages {
 		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
@@ -156,7 +280,7 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	switch stop := r.Stop.(type) {
 	case string:
 		options["stop"] = []string{stop}
-	case []interface{}:
+	case []any:
 		var stops []string
 		for _, s := range stop {
 			if str, ok := s.(string); ok {
@@ -208,13 +332,82 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	}
 }
 
-type writer struct {
-	stream bool
-	id     string
+func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
+	options := make(map[string]any)
+
+	switch stop := r.Stop.(type) {
+	case string:
+		options["stop"] = []string{stop}
+	case []any:
+		var stops []string
+		for _, s := range stop {
+			if str, ok := s.(string); ok {
+				stops = append(stops, str)
+			} else {
+				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
+			}
+		}
+		options["stop"] = stops
+	}
+
+	if r.MaxTokens != nil {
+		options["num_predict"] = *r.MaxTokens
+	}
+
+	if r.Temperature != nil {
+		options["temperature"] = *r.Temperature * 2.0
+	} else {
+		options["temperature"] = 1.0
+	}
+
+	if r.Seed != nil {
+		options["seed"] = *r.Seed
+	}
+
+	options["frequency_penalty"] = r.FrequencyPenalty * 2.0
+
+	options["presence_penalty"] = r.PresencePenalty * 2.0
+
+	if r.TopP != 0.0 {
+		options["top_p"] = r.TopP
+	} else {
+		options["top_p"] = 1.0
+	}
+
+	return api.GenerateRequest{
+		Model:   r.Model,
+		Prompt:  r.Prompt,
+		Options: options,
+		Stream:  &r.Stream,
+	}, nil
+}
+
+type BaseWriter struct {
 	gin.ResponseWriter
 }
 
-func (w *writer) writeError(code int, data []byte) (int, error) {
+type ChatWriter struct {
+	stream bool
+	id     string
+	BaseWriter
+}
+
+type CompleteWriter struct {
+	stream bool
+	id     string
+	BaseWriter
+}
+
+type ListWriter struct {
+	BaseWriter
+}
+
+type RetrieveWriter struct {
+	BaseWriter
+	model string
+}
+
+func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
 	var serr api.StatusError
 	err := json.Unmarshal(data, &serr)
 	if err != nil {
@@ -230,7 +423,7 @@ func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
 	return len(data), nil
 }
 
-func (w *writer) writeResponse(data []byte) (int, error) {
+func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 	var chatResponse api.ChatResponse
 	err := json.Unmarshal(data, &chatResponse)
 	if err != nil {
@@ -270,7 +463,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 	return len(data), nil
 }
 
-func (w *writer) Write(data []byte) (int, error) {
+func (w *ChatWriter) Write(data []byte) (int, error) {
 	code := w.ResponseWriter.Status()
 	if code != http.StatusOK {
 		return w.writeError(code, data)
@@ -279,7 +472,176 @@ func (w *ChatWriter) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }
 
-func Middleware() gin.HandlerFunc {
+func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
+	var generateResponse api.GenerateResponse
+	err := json.Unmarshal(data, &generateResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// completion chunk
+	if w.stream {
+		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		if err != nil {
+			return 0, err
+		}
+
+		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+		if err != nil {
+			return 0, err
+		}
+
+		if generateResponse.Done {
+			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
+			if err != nil {
+				return 0, err
+			}
+		}
+
+		return len(data), nil
+	}
+
+	// completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *CompleteWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *ListWriter) writeResponse(data []byte) (int, error) {
+	var listResponse api.ListResponse
+	err := json.Unmarshal(data, &listResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *ListWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
+	var showResponse api.ShowResponse
+	err := json.Unmarshal(data, &showResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// retrieve completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *RetrieveWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func ListMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		w := &ListWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func RetrieveMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		// response writer
+		w := &RetrieveWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			model:      c.Param("model"),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func CompletionsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req CompletionRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		var b bytes.Buffer
+		genReq, err := fromCompleteRequest(req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &CompleteWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			stream:     req.Stream,
+			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func ChatMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		var req ChatCompletionRequest
 		err := c.ShouldBindJSON(&req)
@@ -294,17 +656,17 @@ func ChatMiddleware() gin.HandlerFunc {
 		}
 
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
+		if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
 			return
 		}
 
 		c.Request.Body = io.NopCloser(&b)
 
-		w := &writer{
-			ResponseWriter: c.Writer,
-			stream:         req.Stream,
-			id:             fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+		w := &ChatWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			stream:     req.Stream,
+			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
 		}
 
 		c.Writer = w
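
The streaming path in CompleteWriter above emits server-sent events: one "data: {json}" line per chunk, a blank line between events, and a literal "data: [DONE]" sentinel after the final chunk. A sketch of how a Go client might consume that stream (helper name invented; assumes bufio, encoding/json, fmt, io, and strings are imported):

    func readStream(body io.Reader) error {
        scanner := bufio.NewScanner(body)
        for scanner.Scan() {
            line := scanner.Text()
            if !strings.HasPrefix(line, "data: ") {
                continue // skip the blank separator lines
            }
            payload := strings.TrimPrefix(line, "data: ")
            if payload == "[DONE]" {
                return nil
            }
            var chunk CompletionChunk
            if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
                return err
            }
            fmt.Print(chunk.Choices[0].Text)
        }
        return scanner.Err()
    }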

openai/openai_test.go (new file, 271 lines)

@@ -0,0 +1,271 @@
package openai
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/assert"
)
func TestMiddlewareRequests(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
Handler func() gin.HandlerFunc
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *http.Request)
}
var capturedRequest *http.Request
captureRequestMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
capturedRequest = c.Request
c.Next()
}
}
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
Handler: CompletionsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var genReq api.GenerateRequest
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
t.Fatal(err)
}
if genReq.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
}
if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
}
stopTokens, ok := genReq.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(captureRequestMiddleware())
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, endpoint)
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest)
})
}
}
func TestMiddlewareResponses(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
TestPath string
Handler func() gin.HandlerFunc
Endpoint func(c *gin.Context)
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
}
testCases := []testCase{
{
Name: "completions handler error forwarding",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
t.Fatalf("error was not forwarded")
}
},
},
{
Name: "list handler",
Method: http.MethodGet,
Path: "/api/tags",
TestPath: "/api/tags",
Handler: ListMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{
{
Name: "Test Model",
},
},
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var listResp ListCompletion
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
if listResp.Object != "list" {
t.Fatalf("expected list, got %s", listResp.Object)
}
if len(listResp.Data) != 1 {
t.Fatalf("expected 1, got %d", len(listResp.Data))
}
if listResp.Data[0].Id != "Test Model" {
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
}
},
},
{
Name: "retrieve model",
Method: http.MethodGet,
Path: "/api/show/:model",
TestPath: "/api/show/test-model",
Handler: RetrieveMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
var retrieveResp Model
if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
t.Fatal(err)
}
if retrieveResp.Object != "model" {
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
}
if retrieveResp.Id != "test-model" {
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, tc.Endpoint)
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, resp)
})
}
}

parser/parser.go

@@ -124,7 +124,7 @@ func ParseFile(r io.Reader) (*File, error) {
 			case stateComment, stateNil:
 				// pass
 			case stateValue:
-				s, ok := unquote(b.String())
+				s, ok := unquote(strings.TrimSpace(b.String()))
 				if !ok || isSpace(r) {
 					if _, err := b.WriteRune(r); err != nil {
 						return nil, err
@@ -158,7 +158,7 @@ func ParseFile(r io.Reader) (*File, error) {
 	case stateComment, stateNil:
 		// pass; nothing to flush
 	case stateValue:
-		s, ok := unquote(b.String())
+		s, ok := unquote(strings.TrimSpace(b.String()))
 		if !ok {
 			return nil, io.ErrUnexpectedEOF
 		}
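
The order matters here: the raw value is trimmed before unquoting, so trailing whitespace outside quotes disappears while whitespace inside quotes survives. A self-contained sketch that mimics the behavior (not the parser's actual code; the test changes in the next file exercise the real thing):

    package main

    import (
        "fmt"
        "strings"
    )

    // unquoteish mimics the order of operations after the fix: trim the raw
    // token first, then strip surrounding quotes if present.
    func unquoteish(s string) string {
        s = strings.TrimSpace(s)
        if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' {
            return s[1 : len(s)-1]
        }
        return s
    }

    func main() {
        fmt.Printf("%q\n", unquoteish(`### User: `))   // "### User:"  (trimmed)
        fmt.Printf("%q\n", unquoteish(`"### User: "`)) // "### User: " (kept)
    }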

parser/parser_test.go

@@ -22,7 +22,13 @@ ADAPTER adapter1
 LICENSE MIT
 PARAMETER param1 value1
 PARAMETER param2 value2
-TEMPLATE template1
+TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ .Response }}<|eot_id|>"""
 `
 
 	reader := strings.NewReader(input)
@@ -36,7 +42,40 @@ TEMPLATE template1
 		{Name: "license", Args: "MIT"},
 		{Name: "param1", Args: "value1"},
 		{Name: "param2", Args: "value2"},
-		{Name: "template", Args: "template1"},
+		{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
+	}
+
+	assert.Equal(t, expectedCommands, modelfile.Commands)
+}
+
+func TestParseFileTrimSpace(t *testing.T) {
+	input := `
+FROM " model 1"
+ADAPTER adapter3
+LICENSE "MIT "
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ .Response }}<|eot_id|> """
+`
+
+	reader := strings.NewReader(input)
+
+	modelfile, err := ParseFile(reader)
+	require.NoError(t, err)
+
+	expectedCommands := []Command{
+		{Name: "model", Args: " model 1"},
+		{Name: "adapter", Args: "adapter3"},
+		{Name: "license", Args: "MIT "},
+		{Name: "param1", Args: "value1"},
+		{Name: "param2", Args: "value2"},
+		{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
 	}
 
 	assert.Equal(t, expectedCommands, modelfile.Commands)
@@ -48,6 +87,26 @@ func TestParseFileFrom(t *testing.T) {
 		expected []Command
 		err      error
 	}{
+		{
+			"FROM \"FOO BAR \"",
+			[]Command{{Name: "model", Args: "FOO BAR "}},
+			nil,
+		},
+		{
+			"FROM \"FOO BAR\"\nPARAMETER param1 value1",
+			[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
+			nil,
+		},
+		{
+			"FROM FOOO BAR ",
+			[]Command{{Name: "model", Args: "FOOO BAR"}},
+			nil,
+		},
+		{
+			"FROM /what/is/the path ",
+			[]Command{{Name: "model", Args: "/what/is/the path"}},
+			nil,
+		},
 		{
 			"FROM foo",
 			[]Command{{Name: "model", Args: "foo"}},
@@ -86,6 +145,11 @@ func TestParseFileFrom(t *testing.T) {
 			[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
 			nil,
 		},
+		{
+			"PARAMETER what the \nFROM lemons make lemonade ",
+			[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
+			nil,
+		},
 	}
 
 	for _, c := range cases {
@@ -399,7 +463,7 @@ func TestParseFileParameters(t *testing.T) {
 		"mirostat_eta 1.0":           {"mirostat_eta", "1.0"},
 		"penalize_newline true":      {"penalize_newline", "true"},
 		"stop ### User:":             {"stop", "### User:"},
-		"stop ### User: ":            {"stop", "### User: "},
+		"stop ### User: ":            {"stop", "### User:"},
 		"stop \"### User:\"":         {"stop", "### User:"},
 		"stop \"### User: \"":        {"stop", "### User: "},
 		"stop \"\"\"### User:\"\"\"": {"stop", "### User:"},

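A short sketch (not part of this diff) of the trimming rule these tests pin down: unquoted command values now have surrounding whitespace trimmed, while quoted values keep it. Package path and behavior are taken from the hunks above; the Modelfile content is invented for illustration.

    package main

    import (
    	"fmt"
    	"strings"

    	"github.com/ollama/ollama/parser"
    )

    func main() {
    	// Unquoted FROM value: trailing space is trimmed. Quoted LICENSE value: space kept.
    	f, err := parser.ParseFile(strings.NewReader("FROM /what/is/the path \nLICENSE \"MIT \"\n"))
    	if err != nil {
    		panic(err)
    	}
    	for _, c := range f.Commands {
    		fmt.Printf("%s=%q\n", c.Name, c.Args) // model="/what/is/the path", license="MIT "
    	}
    }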

@@ -31,6 +31,10 @@ func NewSpinner(message string) *Spinner {
return s return s
} }
func (s *Spinner) SetMessage(message string) {
s.message = message
}
func (s *Spinner) String() string { func (s *Spinner) String() string {
var sb strings.Builder var sb strings.Builder
if len(s.message) > 0 { if len(s.message) > 0 {

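A minimal usage sketch for the new setter, assuming the spinner still lives in the progress package with the NewSpinner constructor shown above:

    // Update the spinner text as a long-running operation moves between phases.
    sp := progress.NewSpinner("copying model data")
    // ... once the copy completes and verification begins:
    sp.SetMessage("verifying sha256 digest")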

@@ -279,7 +279,7 @@ if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\
case $OS_NAME in case $OS_NAME in
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;; centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;; rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
fedora) [ $OS_VERSION -lt '37' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '37';; fedora) [ $OS_VERSION -lt '39' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '39';;
amzn) install_cuda_driver_yum 'fedora' '37' ;; amzn) install_cuda_driver_yum 'fedora' '37' ;;
debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;; debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;;
ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;; ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;;


@@ -6,10 +6,21 @@ set -ex
MACHINE=$(uname -m) MACHINE=$(uname -m)
if grep -i "centos" /etc/system-release >/dev/null; then if grep -i "centos" /etc/system-release >/dev/null; then
# As of 7/1/2024 mirrorlist.centos.org has been taken offline, so adjust accordingly
sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
# Centos 7 derivatives have too old of a git version to run our generate script # Centos 7 derivatives have too old of a git version to run our generate script
# uninstall and ignore failures # uninstall and ignore failures
yum remove -y git yum remove -y git
yum -y install epel-release centos-release-scl yum -y install epel-release centos-release-scl
# The release packages reinstate the mirrors, undo that again
sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
yum -y install dnf yum -y install dnf
if [ "${MACHINE}" = "x86_64" ]; then if [ "${MACHINE}" = "x86_64" ]; then
yum -y install https://repo.ius.io/ius-release-el7.rpm yum -y install https://repo.ius.io/ius-release-el7.rpm


@@ -28,11 +28,19 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
"golang.org/x/crypto/ssh"
) )
var errCapabilityCompletion = errors.New("completion")
type Capability string
const CapabilityCompletion = Capability("completion")
type registryOptions struct { type registryOptions struct {
Insecure bool Insecure bool
Username string Username string
@@ -48,16 +56,50 @@ type Model struct {
ParentModel string ParentModel string
AdapterPaths []string AdapterPaths []string
ProjectorPaths []string ProjectorPaths []string
Template string
System string System string
License []string License []string
Digest string Digest string
Options map[string]interface{} Options map[string]interface{}
Messages []Message Messages []Message
Template *template.Template
} }
func (m *Model) IsEmbedding() bool { // CheckCapabilities checks if the model has the specified capabilities returning an error describing
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") // any missing or unknown capabilities
func (m *Model) CheckCapabilities(caps ...Capability) error {
var errs []error
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
f, err := os.Open(m.ModelPath)
if err != nil {
slog.Error("couldn't open model file", "error", err)
continue
}
defer f.Close()
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
ggml, _, err := llm.DecodeGGML(f, 0)
if err != nil {
slog.Error("couldn't decode ggml", "error", err)
continue
}
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
errs = append(errs, errCapabilityCompletion)
}
default:
slog.Error("unknown capability", "capability", cap)
return fmt.Errorf("unknown capability: %s", cap)
}
}
if err := errors.Join(errs...); err != nil {
return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
}
return nil
} }
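For reference, a hedged sketch of how a caller inside package server can gate completion-only endpoints on this check; it mirrors the scheduleRunner usage later in this diff, with model and name as hypothetical locals.

    // Reject models whose GGML metadata carries a pooling_type key (embedding-only)
    // before trying to schedule a text completion.
    if err := model.CheckCapabilities(CapabilityCompletion); err != nil {
    	return fmt.Errorf("%s %w", name, err)
    }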
func (m *Model) String() string { func (m *Model) String() string {
@@ -82,10 +124,10 @@ func (m *Model) String() string {
}) })
} }
if m.Template != "" { if m.Template != nil {
modelfile.Commands = append(modelfile.Commands, parser.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "template", Name: "template",
Args: m.Template, Args: m.Template.String(),
}) })
} }
@@ -135,13 +177,6 @@ type Message struct {
Content string `json:"content"` Content string `json:"content"`
} }
type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
}
type ConfigV2 struct { type ConfigV2 struct {
ModelFormat string `json:"model_format"` ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"` ModelFamily string `json:"model_family"`
@@ -160,7 +195,7 @@ type RootFS struct {
DiffIDs []string `json:"diff_ids"` DiffIDs []string `json:"diff_ids"`
} }
func GetManifest(mp ModelPath) (*ManifestV2, string, error) { func GetManifest(mp ModelPath) (*Manifest, string, error) {
fp, err := mp.GetManifestPath() fp, err := mp.GetManifestPath()
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@@ -170,7 +205,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
return nil, "", err return nil, "", err
} }
var manifest *ManifestV2 var manifest *Manifest
bts, err := os.ReadFile(fp) bts, err := os.ReadFile(fp)
if err != nil { if err != nil {
@@ -198,8 +233,7 @@ func GetModel(name string) (*Model, error) {
Name: mp.GetFullTagname(), Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(), ShortName: mp.GetShortTagname(),
Digest: digest, Digest: digest,
Template: "{{ .Prompt }}", Template: template.DefaultTemplate,
License: []string{},
} }
filename, err := GetBlobsPath(manifest.Config.Digest) filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -235,13 +269,17 @@ func GetModel(name string) (*Model, error) {
model.AdapterPaths = append(model.AdapterPaths, filename) model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector": case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename) model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.template": case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model.Template = string(bts) model.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
case "application/vnd.ollama.image.system": case "application/vnd.ollama.image.system":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
@@ -249,13 +287,6 @@ func GetModel(name string) (*Model, error) {
} }
model.System = string(bts) model.System = string(bts)
case "application/vnd.ollama.image.prompt":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.Template = string(bts)
case "application/vnd.ollama.image.params": case "application/vnd.ollama.image.params":
params, err := os.Open(filename) params, err := os.Open(filename)
if err != nil { if err != nil {
@@ -414,17 +445,22 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return err return err
} }
layers, err := parseFromFile(ctx, temp, "", fn) layer, err := NewLayer(temp, baseLayer.MediaType)
if err != nil { if err != nil {
return err return err
} }
if len(layers) != 1 { if _, err := temp.Seek(0, io.SeekStart); err != nil {
return errors.New("quantization failed") return err
} }
baseLayer.Layer = layers[0].Layer ggml, _, err := llm.DecodeGGML(temp, 0)
baseLayer.GGML = layers[0].GGML if err != nil {
return err
}
baseLayer.Layer = layer
baseLayer.GGML = ggml
} }
} }
@@ -817,7 +853,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
var manifest *ManifestV2 var manifest *Manifest
var err error var err error
var noprune string var noprune string
@@ -924,7 +960,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil return nil
} }
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) { func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header) headers := make(http.Header)
@@ -935,7 +971,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
} }
defer resp.Body.Close() defer resp.Body.Close()
var m *ManifestV2 var m *Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err return nil, err
} }
@@ -1029,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if anonymous { if anonymous {
// no user is associated with the public key, and the request requires non-anonymous access // no user is associated with the public key, and the request requires non-anonymous access
pubKey, nestedErr := auth.GetPublicKey() pubKey, nestedErr := auth.GetPublicKey()
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
if nestedErr != nil { if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr)) slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, errUnauthorized return nil, errUnauthorized
} }
return nil, &errtypes.UnknownOllamaKey{Key: pubKey} return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
} }
// user is associated with the public key, but is not authorized to make the request // user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized return nil, errUnauthorized

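The TrimSpace added above matters because x/crypto/ssh's MarshalAuthorizedKey returns the authorized_keys form with a trailing newline; a small sketch:

    raw := string(ssh.MarshalAuthorizedKey(pubKey)) // "ssh-ed25519 AAAA...\n" — ends with a newline
    key := strings.TrimSpace(raw)                   // single-line form, safe to embed in an error message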

@@ -14,7 +14,10 @@ import (
) )
type Manifest struct { type Manifest struct {
ManifestV2 SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
filepath string filepath string
fi os.FileInfo fi os.FileInfo
@@ -66,7 +69,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
p := filepath.Join(manifests, n.Filepath()) p := filepath.Join(manifests, n.Filepath())
var m ManifestV2 var m Manifest
f, err := os.Open(p) f, err := os.Open(p)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -83,12 +86,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err return nil, err
} }
return &Manifest{ m.filepath = p
ManifestV2: m, m.fi = fi
filepath: p, m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), return &m, nil
}, nil
} }
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
@@ -108,7 +110,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
} }
defer f.Close() defer f.Close()
m := ManifestV2{ m := Manifest{
SchemaVersion: 2, SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json", MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config, Config: config,

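With ManifestV2 folded into Manifest, reading a manifest file is a single unmarshal into one struct; a brief sketch assuming only the fields shown above, with p as a hypothetical manifest path:

    var m Manifest
    f, err := os.Open(p)
    if err != nil {
    	return nil, err
    }
    defer f.Close()
    if err := json.NewDecoder(f).Decode(&m); err != nil {
    	return nil, err
    }
    // m.SchemaVersion, m.Config, and m.Layers are populated directly; filepath, fi,
    // and digest are filled in afterwards, as ParseNamedManifest now does in place.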

@@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) {
} }
defer f.Close() defer f.Close()
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil { if err := json.NewEncoder(f).Encode(Manifest{}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }


@@ -15,7 +15,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert" "github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/templates" "github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@@ -63,7 +63,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
} }
defer blob.Close() defer blob.Close()
ggml, _, err := llm.DecodeGGML(blob) ggml, _, err := llm.DecodeGGML(blob, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -77,62 +77,79 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
return layers, nil return layers, nil
} }
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) { func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error {
stat, err := file.Stat() stat, err := file.Stat()
if err != nil { if err != nil {
return nil, err return err
} }
r, err := zip.NewReader(file, stat.Size()) r, err := zip.NewReader(file, stat.Size())
if err != nil { if err != nil {
return nil, err return err
} }
tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
if err != nil {
return nil, err
}
defer os.RemoveAll(tempdir)
fn(api.ProgressResponse{Status: "unpacking model metadata"}) fn(api.ProgressResponse{Status: "unpacking model metadata"})
for _, f := range r.File { for _, f := range r.File {
if !filepath.IsLocal(f.Name) {
return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
}
n := filepath.Join(p, f.Name)
if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
return err
}
// TODO(mxyng): this should not write out all files to disk // TODO(mxyng): this should not write out all files to disk
outfile, err := os.Create(filepath.Join(tempdir, f.Name)) outfile, err := os.Create(n)
if err != nil { if err != nil {
return nil, err return err
} }
defer outfile.Close() defer outfile.Close()
infile, err := f.Open() infile, err := f.Open()
if err != nil { if err != nil {
return nil, err return err
} }
defer infile.Close() defer infile.Close()
if _, err = io.Copy(outfile, infile); err != nil { if _, err = io.Copy(outfile, infile); err != nil {
return nil, err return err
} }
if err := outfile.Close(); err != nil { if err := outfile.Close(); err != nil {
return nil, err return err
} }
if err := infile.Close(); err != nil { if err := infile.Close(); err != nil {
return nil, err return err
} }
} }
mf, err := convert.GetModelFormat(tempdir) return nil
}
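The filepath.IsLocal guard above is what turns hostile archive entries into zip.ErrInsecurePath; a standalone sketch of the lexical check (available since Go 1.20):

    fmt.Println(filepath.IsLocal("good"))              // true
    fmt.Println(filepath.IsLocal("path/../to/good"))   // true  — cleans to "to/good", stays inside the root
    fmt.Println(filepath.IsLocal("path/../../to/bad")) // false — escapes the extraction directory
    fmt.Println(filepath.IsLocal("/etc/passwd"))       // false — absolute paths are never local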
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
tempDir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
if err != nil {
return nil, err
}
defer os.RemoveAll(tempDir)
if err := extractFromZipFile(tempDir, file, fn); err != nil {
return nil, err
}
mf, err := convert.GetModelFormat(tempDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
params, err := mf.GetParams(tempdir) params, err := mf.GetParams(tempDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mArch, err := mf.GetModelArch("", tempdir, params) mArch, err := mf.GetModelArch("", tempDir, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -150,7 +167,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
// TODO(mxyng): this should write directly into a layer // TODO(mxyng): this should write directly into a layer
// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model") // e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
temp, err := os.CreateTemp(tempdir, "fp16") temp, err := os.CreateTemp(tempDir, "fp16")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -176,7 +193,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
} }
defer bin.Close() defer bin.Close()
ggml, _, err := llm.DecodeGGML(bin) ggml, _, err := llm.DecodeGGML(bin, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -210,7 +227,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
var offset int64 var offset int64
for offset < stat.Size() { for offset < stat.Size() {
ggml, n, err := llm.DecodeGGML(file) ggml, n, err := llm.DecodeGGML(file, 0)
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
break break
} else if err != nil { } else if err != nil {
@@ -239,7 +256,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) { func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers { for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" { if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := templates.NamedTemplate(s); err != nil { if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err) slog.Debug("template detection", "error", err)
} else { } else {
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template") tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")

server/model_test.go (new file, 112 lines added)

@@ -0,0 +1,112 @@
package server
import (
"archive/zip"
"bytes"
"errors"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func createZipFile(t *testing.T, name string) *os.File {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Fatal(err)
}
zf := zip.NewWriter(f)
defer zf.Close()
zh, err := zf.CreateHeader(&zip.FileHeader{Name: name})
if err != nil {
t.Fatal(err)
}
if _, err := io.Copy(zh, bytes.NewReader([]byte(""))); err != nil {
t.Fatal(err)
}
return f
}
func TestExtractFromZipFile(t *testing.T) {
cases := []struct {
name string
expect []string
err error
}{
{
name: "good",
expect: []string{"good"},
},
{
name: strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
expect: []string{filepath.Join("to", "good")},
},
{
name: strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
expect: []string{"good"},
},
{
name: strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
expect: []string{"good"},
},
{
name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
err: zip.ErrInsecurePath,
},
{
name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
err: zip.ErrInsecurePath,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
f := createZipFile(t, tt.name)
defer f.Close()
tempDir := t.TempDir()
if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
t.Fatal(err)
}
var matches []string
if err := filepath.Walk(tempDir, func(p string, fi os.FileInfo, err error) error {
if err != nil {
return err
}
if !fi.IsDir() {
matches = append(matches, p)
}
return nil
}); err != nil {
t.Fatal(err)
}
var actual []string
for _, match := range matches {
rel, err := filepath.Rel(tempDir, match)
if err != nil {
t.Error(err)
}
actual = append(actual, rel)
}
if !slices.Equal(actual, tt.expect) {
t.Fatalf("expected %d files, got %d", len(tt.expect), len(matches))
}
})
}
}


@@ -103,18 +103,9 @@ func (mp ModelPath) GetShortTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag) return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
} }
// modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
// The models directory is where Ollama stores its model files and manifests.
func modelsDir() (string, error) {
return envconfig.ModelsDir, nil
}
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist. // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) { func (mp ModelPath) GetManifestPath() (string, error) {
dir, err := modelsDir() dir := envconfig.ModelsDir
if err != nil {
return "", err
}
return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
} }
@@ -127,10 +118,7 @@ func (mp ModelPath) BaseURL() *url.URL {
} }
func GetManifestPath() (string, error) { func GetManifestPath() (string, error) {
dir, err := modelsDir() dir := envconfig.ModelsDir
if err != nil {
return "", err
}
path := filepath.Join(dir, "manifests") path := filepath.Join(dir, "manifests")
if err := os.MkdirAll(path, 0o755); err != nil { if err := os.MkdirAll(path, 0o755); err != nil {
@@ -141,10 +129,7 @@ func GetManifestPath() (string, error) {
} }
func GetBlobsPath(digest string) (string, error) { func GetBlobsPath(digest string) (string, error) {
dir, err := modelsDir() dir := envconfig.ModelsDir
if err != nil {
return "", err
}
// only accept actual sha256 digests // only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$" pattern := "^sha256[:-][0-9a-fA-F]{64}$"


@@ -1,221 +1,83 @@
package server package server
import ( import (
"fmt" "bytes"
"context"
"log/slog" "log/slog"
"strings" "slices"
"text/template"
"text/template/parse"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template"
) )
// isResponseNode checks if the node contains .Response type tokenizeFunc func(context.Context, string) ([]int, error)
func isResponseNode(node *parse.ActionNode) bool {
for _, cmd := range node.Pipe.Cmds { // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
for _, arg := range cmd.Args { // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 { // latest message and 2) system messages
if fieldNode.Ident[0] == "Response" { func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
return true // pull out any system messages which should always be included in the prompt
} var system []api.Message
} msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
system = append(system, m)
return true
} }
}
return false
}
// formatTemplateForResponse formats the template AST to: return false
// 1. remove all nodes after the first .Response (if generate=true) })
// 2. add a .Response node to the end if it doesn't exist
// TODO(jmorganca): this should recursively cut the template before the first .Response if len(system) == 0 && m.System != "" {
func formatTemplateForResponse(tmpl *template.Template, generate bool) { // add model system prompt since it wasn't provided
var found bool system = append(system, api.Message{Role: "system", Content: m.System})
for i, node := range tmpl.Tree.Root.Nodes { }
if actionNode, ok := node.(*parse.ActionNode); ok {
if isResponseNode(actionNode) { // always include the last message
found = true n := len(msgs) - 1
if generate { // in reverse, find all messages that fit into context window
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1] for i := n - 1; i >= 0; i-- {
break var b bytes.Buffer
} if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
} return "", nil, err
} }
}
if !found { s, err := tokenize(ctx, b.String())
// add the response node if it doesn't exist
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
}
}
// Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
if err != nil {
return "", err
}
formatTemplateForResponse(parsed, generate)
vars := map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}
var sb strings.Builder
if err := parsed.Execute(&sb, vars); err != nil {
return "", err
}
return sb.String(), nil
}
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil {
return 0, err
}
tokens, err := encode(rendered)
if err != nil {
slog.Error("failed to encode prompt", "err", err)
return 0, err
}
return len(tokens), err
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string
Response string
images []int
tokens int
}
var p prompt
// iterate through messages to build up {system,user,response} prompts
var imgId int
var prompts []prompt
for _, msg := range messages {
switch strings.ToLower(msg.Role) {
case "system":
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.System = msg.Content
case "user":
if p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
var sb strings.Builder
for range msg.Images {
fmt.Fprintf(&sb, "[img-%d] ", imgId)
p.images = append(p.images, imgId)
imgId += 1
}
sb.WriteString(msg.Content)
p.Prompt = sb.String()
case "assistant":
if p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.Response = msg.Content
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// add final prompt
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
}
// calculate token lengths for each prompt, estimating 768 tokens per image
for i, p := range prompts {
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
if err != nil { if err != nil {
return "", err return "", nil, err
} }
prompts[i].tokens = tokens + len(prompts[i].images)*768 c := len(s)
} if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// truncate images and prompts starting from the beginning of the list // images are represented as 768 sized embeddings
// until either one prompt remains or the total tokens fits the context window // TODO: get embedding length from project metadata
// TODO (jmorganca): this doesn't account for the context window room required for the response c += 768 * len(m.Images)
for { }
var required int
for _, p := range prompts {
required += p.tokens
} }
required += 1 // for bos token if c > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
if required <= window {
slog.Debug("prompt now fits in context window", "required", required, "window", window)
break break
} else {
n = i
} }
prompt := &prompts[0]
if len(prompt.images) > 1 {
img := prompt.images[0]
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
prompt.images = prompt.images[1:]
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
prompt.tokens -= 768
continue
}
if len(prompts) > 1 {
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
system := prompt.System
prompts = prompts[1:]
if system != "" && prompts[0].System == "" {
prompts[0].System = system
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
if err != nil {
return "", err
}
prompts[0].tokens = tokens + len(prompts[0].images)*768
}
continue
}
// stop truncating if there's only one prompt left
break
} }
var sb strings.Builder // truncate any messages that do not fit into the context window
for i, p := range prompts { var b bytes.Buffer
// last prompt should leave the response unrendered (for completion) if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1) return "", nil, err
if err != nil {
return "", err
}
sb.WriteString(rendered)
} }
return sb.String(), nil for _, m := range msgs[n:] {
for _, i := range m.Images {
images = append(images, llm.ImageData{
ID: len(images),
Data: i,
})
}
}
return b.String(), images, nil
} }

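A hedged sketch of how a handler inside package server can call the new chatPrompt, passing the runner's tokenizer so truncation is measured in real model tokens; r, m, opts, and req are hypothetical locals matching the signatures shown above.

    prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
    if err != nil {
    	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
    	return
    }
    // prompt holds the rendered template for the most recent messages that fit in opts.NumCtx
    // (system messages are always kept); images carries the ImageData those messages referenced.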

@@ -1,204 +1,204 @@
package server package server
import ( import (
"bytes"
"context"
"strings" "strings"
"testing" "testing"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
func TestPrompt(t *testing.T) { func tokenize(_ context.Context, s string) (tokens []int, err error) {
tests := []struct { for range strings.Fields(s) {
name string tokens = append(tokens, len(tokens))
template string
system string
prompt string
response string
generate bool
want string
}{
{
name: "simple prompt",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "implicit response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
},
{
name: "response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
{
name: "cut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
generate: true,
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
},
{
name: "nocut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
},
} }
for _, tc := range tests { return
t.Run(tc.name, func(t *testing.T) {
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil {
t.Errorf("error = %v", err)
}
if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
}
})
}
} }
func TestChatPrompt(t *testing.T) { func TestChatPrompt(t *testing.T) {
tests := []struct { type expect struct {
name string prompt string
template string images [][]byte
messages []api.Message }
window int
want string cases := []struct {
name string
limit int
msgs []api.Message
expect
}{ }{
{ {
name: "simple prompt", name: "messages",
template: "[INST] {{ .Prompt }} [/INST]", limit: 64,
messages: []api.Message{ msgs: []api.Message{
{Role: "user", Content: "Hello"}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
}, },
window: 1024,
want: "[INST] Hello [/INST]",
}, },
{ {
name: "with system message", name: "truncate messages",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]", limit: 1,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Content: "Hello"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "A test. And a thumping good one at that, I'd wager. ",
}, },
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
}, },
{ {
name: "with response", name: "truncate messages with image",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}", limit: 64,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Content: "Hello"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "assistant", Content: "I am?"}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
},
}, },
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
}, },
{ {
name: "with implicit response", name: "truncate messages with images",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]", limit: 64,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Content: "Hello"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "assistant", Content: "I am?"}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
}, },
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
}, },
{ {
name: "with conversation", name: "messages with images",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", limit: 2048,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Content: "What are the potion ingredients?"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "assistant", Content: "sugar"}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "user", Content: "Anything else?"}, },
expect: expect{
prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
}, },
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
}, },
{ {
name: "with truncation", name: "message with image tag",
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ", limit: 2048,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Content: "Hello"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "assistant", Content: "I am?"}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "user", Content: "Why is the sky blue?"}, },
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"}, expect: expect{
prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
}, },
window: 10,
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
}, },
{ {
name: "images", name: "messages with interleaved images",
template: "{{ .System }} {{ .Prompt }}", limit: 2048,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}}, {Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
}, },
window: 1024,
want: "You are a Wizard. [img-0] Hello",
}, },
{ {
name: "images truncated", name: "truncate message with interleaved images",
template: "{{ .System }} {{ .Prompt }}", limit: 1024,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}}, {Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
}, },
window: 1024,
want: "You are a Wizard. [img-0] [img-1] Hello",
}, },
{ {
name: "empty list", name: "message with system prompt",
template: "{{ .System }} {{ .Prompt }}", limit: 2048,
messages: []api.Message{}, msgs: []api.Message{
window: 1024, {Role: "system", Content: "You are the Test Who Lived."},
want: "", {Role: "user", Content: "You're a test, Harry!"},
}, {Role: "assistant", Content: "I-I'm a what?"},
{ {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
name: "empty prompt", },
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", expect: expect{
messages: []api.Message{ prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
{Role: "user", Content: ""},
}, },
window: 1024,
want: "",
}, },
} }
encode := func(s string) ([]int, error) { tmpl, err := template.Parse(`
words := strings.Fields(s) {{- if .System }}{{ .System }} {{ end }}
return make([]int, len(words)), nil {{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
if err != nil {
t.Fatal(err)
} }
for _, tc := range tests { for _, tt := range cases {
t.Run(tc.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode) model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
if err != nil { if err != nil {
t.Errorf("error = %v", err) t.Fatal(err)
} }
if got != tc.want { if tt.prompt != prompt {
t.Errorf("got: %q, want: %q", got, tc.want) t.Errorf("expected %q, got %q", tt.prompt, prompt)
}
if len(images) != len(tt.images) {
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
}
for i := range images {
if images[i].ID != i {
t.Errorf("expected ID %d, got %d", i, images[i].ID)
}
if !bytes.Equal(images[i].Data, tt.images[i]) {
t.Errorf("expected %q, got %q", tt.images[i], images[i])
}
} }
}) })
} }


@@ -1,15 +1,16 @@
package server package server
import ( import (
"bytes"
"cmp" "cmp"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs" "log"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@@ -17,20 +18,22 @@ import (
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"slices" "slices"
"strconv"
"strings" "strings"
"syscall" "syscall"
"time" "time"
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/crypto/ssh"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@@ -55,7 +58,7 @@ func init() {
gin.SetMode(mode) gin.SetMode(mode)
} }
var defaultSessionDuration = 5 * time.Minute var errRequired = errors.New("is required")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions() opts := api.DefaultOptions()
@@ -70,164 +73,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil return opts, nil
} }
func isSupportedImageType(image []byte) bool { // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
contentType := http.DetectContentType(image) // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"} func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
return slices.Contains(allowedTypes, contentType) if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
}
model, err := GetModel(name)
if err != nil {
return nil, nil, nil, err
}
if err := model.CheckCapabilities(caps...); err != nil {
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
opts, err := modelOptions(model, requestOpts)
if err != nil {
return nil, nil, nil, err
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err = <-errCh:
return nil, nil, nil, err
}
return runner.llama, model, &opts, nil
} }
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
// validate the request if req.Format != "" && req.Format != "json" {
switch { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
case len(req.Format) > 0 && req.Format != "json": } else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return return
} }
for _, img := range req.Images { caps := []Capability{CapabilityCompletion}
if !isSupportedImageType(img) { r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) if errors.Is(err, errCapabilityCompletion) {
return c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
} return
} } else if err != nil {
handleScheduleError(c, req.Model, err)
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if model.IsEmbedding() { if req.Prompt == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{ c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true, Done: true,
DoneReason: "load", DoneReason: "load",
}) })
return return
} }
checkpointLoaded := time.Now() images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
var prompt string prompt := req.Prompt
switch { if !req.Raw {
case req.Raw: var msgs []api.Message
prompt = req.Prompt if req.System != "" {
case req.Prompt != "": msgs = append(msgs, api.Message{Role: "system", Content: req.System})
if req.Template == "" { } else if m.System != "" {
req.Template = model.Template msgs = append(msgs, api.Message{Role: "system", Content: m.System})
} }
if req.System == "" { for _, i := range images {
req.System = model.System msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
} }
slog.Debug("generate handler", "prompt", req.Prompt) msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
var sb strings.Builder tmpl := m.Template
for i := range req.Images { if req.Template != "" {
fmt.Fprintf(&sb, "[img-%d] ", i) tmpl, err = template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} }
sb.WriteString(req.Prompt) var b bytes.Buffer
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.Reset()
if req.Context != nil { if req.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context) s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sb.WriteString(prev) b.WriteString(s)
} }
sb.WriteString(p) if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt = sb.String() prompt = b.String()
} }
slog.Debug("generate handler", "prompt", prompt) slog.Debug("generate request", "prompt", prompt, "images", images)
ch := make(chan any) ch := make(chan any)
var generated strings.Builder
go func() { go func() {
defer close(ch) defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
fn := func(r llm.CompletionResponse) { Prompt: prompt,
// Build up the full response Images: images,
if _, err := generated.WriteString(r.Content); err != nil { Format: req.Format,
ch <- gin.H{"error": err.Error()} Options: opts,
return }, func(r llm.CompletionResponse) {
} ch <- api.GenerateResponse{
resp := api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Done: r.Done,
Response: r.Content, Response: r.Content,
Done: r.Done,
DoneReason: r.DoneReason, DoneReason: r.DoneReason,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
@@ -236,156 +215,54 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalDuration: r.EvalDuration, EvalDuration: r.EvalDuration,
}, },
} }
}); err != nil {
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = append(req.Context, tokens...)
}
}
ch <- resp
}
var images []llm.ImageData
for i := range req.Images {
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
})
}
// Start prediction
req := llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response var r api.GenerateResponse
var final api.GenerateResponse
var sb strings.Builder var sb strings.Builder
for resp := range ch { for rr := range ch {
switch r := resp.(type) { switch t := rr.(type) {
case api.GenerateResponse: case api.GenerateResponse:
sb.WriteString(r.Response) sb.WriteString(t.Response)
final = r r = t
case gin.H: case gin.H:
if errorMsg, ok := r["error"].(string); ok { msg, ok := t["error"].(string)
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg}) if !ok {
return msg = "unexpected error format in response"
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default: default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return return
} }
} }
final.Response = sb.String() r.Response = sb.String()
c.JSON(http.StatusOK, final) c.JSON(http.StatusOK, r)
return return
} }
streamResponse(c, ch) streamResponse(c, ch)
} }
func getDefaultSessionDuration() time.Duration {
if envconfig.KeepAlive != "" {
v, err := strconv.Atoi(envconfig.KeepAlive)
if err != nil {
d, err := time.ParseDuration(envconfig.KeepAlive)
if err != nil {
return defaultSessionDuration
}
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}
d := time.Duration(v) * time.Second
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}
return defaultSessionDuration
}
func (s *Server) EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
if req.Model == "" { r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
model, err := GetModel(req.Model)
if err != nil { if err != nil {
var pErr *fs.PathError handleScheduleError(c, req.Model, err)
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return return
} }
@@ -395,17 +272,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
		return
	}

-	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
	if err != nil {
		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
		return
	}

-	resp := api.EmbeddingResponse{
-		Embedding: embedding,
-	}
-	c.JSON(http.StatusOK, resp)
+	c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
}
func (s *Server) PullModelHandler(c *gin.Context) {
@@ -680,12 +554,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
	}

	if req.Template != "" {
-		m.Template = req.Template
+		m.Template, err = template.Parse(req.Template)
+		if err != nil {
+			return nil, err
+		}
	}

-	msgs := make([]api.Message, 0)
-	for _, msg := range m.Messages {
-		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
+	msgs := make([]api.Message, len(m.Messages))
+	for i, msg := range m.Messages {
+		msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
	}

	n := model.ParseName(req.Model)
@@ -701,7 +578,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
	resp := &api.ShowResponse{
		License:    strings.Join(m.License, "\n"),
		System:     m.System,
-		Template:   m.Template,
+		Template:   m.Template.String(),
		Details:    modelDetails,
		Messages:   msgs,
		ModifiedAt: manifest.fi.ModTime(),
@@ -754,7 +631,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}

func getKVData(digest string, verbose bool) (llm.KV, error) {
-	kvData, err := llm.LoadModel(digest)
+	maxArraySize := 0
+	if verbose {
+		maxArraySize = -1
+	}
+	kvData, err := llm.LoadModel(digest, maxArraySize)
	if err != nil {
		return nil, err
	}
@@ -893,7 +774,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	_, err = os.Stat(path)
	switch {
	case errors.Is(err, os.ErrNotExist):
@@ -906,6 +786,12 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
		return
	}
if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
c.Header("LocalLocation", path)
c.Status(http.StatusTemporaryRedirect)
return
}
	layer, err := NewLayer(c.Request.Body, "")
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -920,6 +806,54 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
	c.Status(http.StatusCreated)
}
func (s *Server) IsLocal(c *gin.Context) bool {
if authz := c.GetHeader("Authorization"); authz != "" {
parts := strings.Split(authz, ":")
if len(parts) != 3 {
return false
}
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
if err != nil {
return false
}
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
requestData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return false
}
partialRequestDataParts := strings.Split(string(requestData), ",")
if len(partialRequestDataParts) != 3 {
return false
}
signature, err := base64.StdEncoding.DecodeString(parts[2])
if err != nil {
return false
}
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
return false
}
serverPublicKey, err := auth.GetPublicKey()
if err != nil {
log.Fatal(err)
}
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
return true
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return false
}
return false
}
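IsLocal therefore expects Authorization to carry three colon-separated base64 fields: the client's ed25519 public key, the signed request digest, and the signature; note the verifier splits the digest on commas and requires exactly three fields even though the comment above lists four. A hedged client-side sketch of building such a header (the helper name and digest fields are assumptions, not code from this change):

package main

import (
	"crypto/ed25519"
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
)

// buildLocalAuthHeader produces the pubkey:payload:signature shape the
// IsLocal check above verifies. All three fields are standard base64.
func buildLocalAuthHeader(signer ssh.Signer, method, requestURI string) (string, error) {
	// The verifier splits the decoded payload on "," and expects three fields.
	payload := strings.Join([]string{method, requestURI, fmt.Sprint(time.Now().Unix())}, ",")

	sig, err := signer.Sign(rand.Reader, []byte(payload))
	if err != nil {
		return "", err
	}

	// parts[0] must decode via ssh.ParseAuthorizedKey("ssh-ed25519 " + key),
	// so send the SSH wire encoding of the public key.
	pub := base64.StdEncoding.EncodeToString(signer.PublicKey().Marshal())
	return strings.Join([]string{
		pub,
		base64.StdEncoding.EncodeToString([]byte(payload)),
		base64.StdEncoding.EncodeToString(sig.Blob),
	}, ":"), nil
}

func main() {
	_, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		panic(err)
	}
	signer, err := ssh.NewSignerFromKey(priv)
	if err != nil {
		panic(err)
	}
	h, err := buildLocalAuthHeader(signer, "POST", "/api/blobs/sha256-abc")
	if err != nil {
		panic(err)
	}
	fmt.Println(h)
}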
func isLocalIP(ip netip.Addr) bool {
	if interfaces, err := net.Interfaces(); err == nil {
		for _, iface := range interfaces {
@@ -1035,7 +969,10 @@ func (s *Server) GenerateRoutes() http.Handler {
	r.GET("/api/ps", s.ProcessHandler)

	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
+	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
+	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
+	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
+	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
	for _, method := range []string{http.MethodGet, http.MethodHead} {
		r.Handle(method, "/", func(c *gin.Context) {
@@ -1101,11 +1038,20 @@ func Serve(ln net.Listener) error {
	schedCtx, schedDone := context.WithCancel(ctx)
	sched := InitScheduler(schedCtx)
	s := &Server{addr: ln.Addr(), sched: sched}

-	r := s.GenerateRoutes()
+	http.Handle("/", s.GenerateRoutes())

	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
	srvr := &http.Server{
-		Handler: r,
+		// Use http.DefaultServeMux so we get net/http/pprof for
+		// free.
+		//
+		// TODO(bmizerany): Decide if we want to make this
+		// configurable so it is not exposed by default, or allow
+		// users to bind it to a different port. This was a quick
+		// and easy way to get pprof, but it may not be the best
+		// way.
+		Handler: nil,
	}
	// listen for a ctrl+c and stop any loaded llm
@@ -1224,142 +1170,63 @@ func (s *Server) ProcessHandler(c *gin.Context) {
		models = append(models, mr)
	}

+	slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
+		// longest duration remaining listed first
+		return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
+	})

	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return runner.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
if err != nil {
return "", err
}
return prompt, nil
}
func (s *Server) ChatHandler(c *gin.Context) {
-	checkpointStart := time.Now()

	var req api.ChatRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
-	case err != nil:
+	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

-	// validate the request
-	switch {
-	case req.Model == "":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	case len(req.Format) > 0 && req.Format != "json":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
-		return
-	}
+	caps := []Capability{CapabilityCompletion}
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	if errors.Is(err, errCapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
+		return
+	} else if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}

-	model, err := GetModel(req.Model)
-	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	if model.IsEmbedding() {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
-		return
-	}
-
-	checkpointLoaded := time.Now()
-
-	// if the first message is not a system message, then add the model's default system message
-	if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
-		req.Messages = append([]api.Message{
-			{
-				Role:    "system",
-				Content: model.System,
-			},
-		}, req.Messages...)
-	}
-
-	prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
-	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
-	}
-
-	// an empty request loads the model
-	if len(req.Messages) == 0 || prompt == "" {
-		resp := api.ChatResponse{
-			CreatedAt:  time.Now().UTC(),
-			Model:      req.Model,
-			Done:       true,
-			DoneReason: "load",
-			Message:    api.Message{Role: "assistant"},
-		}
-		c.JSON(http.StatusOK, resp)
-		return
-	}
+	if len(req.Messages) == 0 {
+		c.JSON(http.StatusOK, api.ChatResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant"},
+			Done:       true,
+			DoneReason: "load",
+		})
+		return
+	}

-	// only send images that are in the prompt
-	var i int
-	var images []llm.ImageData
-	for _, m := range req.Messages {
-		for _, img := range m.Images {
-			if !isSupportedImageType(img) {
-				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-				return
-			}
-
-			if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
-				images = append(images, llm.ImageData{Data: img, ID: i})
-			}
-
-			i += 1
-		}
-	}
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
-	slog.Debug("chat handler", "prompt", prompt, "images", len(images))
+	slog.Debug("chat request", "images", len(images), "prompt", prompt)
	ch := make(chan any)
	go func() {
		defer close(ch)
-		fn := func(r llm.CompletionResponse) {
-			resp := api.ChatResponse{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
+		}, func(r llm.CompletionResponse) {
+			ch <- api.ChatResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Message:   api.Message{Role: "assistant", Content: r.Content},
@@ -1372,64 +1239,52 @@ func (s *Server) ChatHandler(c *gin.Context) {
					EvalDuration: r.EvalDuration,
				},
			}
-
-			if r.Done {
-				resp.TotalDuration = time.Since(checkpointStart)
-				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-			}
-
-			ch <- resp
-		}
-
-		if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Format:  req.Format,
-			Images:  images,
-			Options: opts,
-		}, fn); err != nil {
+		}); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()
	if req.Stream != nil && !*req.Stream {
-		// Accumulate responses into the final response
-		var final api.ChatResponse
+		var r api.ChatResponse
		var sb strings.Builder
-		for resp := range ch {
-			switch r := resp.(type) {
+		for rr := range ch {
+			switch t := rr.(type) {
			case api.ChatResponse:
-				sb.WriteString(r.Message.Content)
-				final = r
+				sb.WriteString(t.Message.Content)
+				r = t
			case gin.H:
-				if errorMsg, ok := r["error"].(string); ok {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
-					return
-				} else {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
-					return
-				}
+				msg, ok := t["error"].(string)
+				if !ok {
+					msg = "unexpected error format in response"
+				}
+
+				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+				return
			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
				return
			}
		}

-		final.Message = api.Message{Role: "assistant", Content: sb.String()}
-		c.JSON(http.StatusOK, final)
+		r.Message.Content = sb.String()
+		c.JSON(http.StatusOK, r)
		return
	}

	streamResponse(c, ch)
}
-func handleErrorResponse(c *gin.Context, err error) {
-	if errors.Is(err, context.Canceled) {
-		c.JSON(499, gin.H{"error": "request canceled"})
-		return
-	}
-
-	if errors.Is(err, ErrMaxQueue) {
-		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
-		return
-	}
-
-	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+func handleScheduleError(c *gin.Context, name string, err error) {
+	switch {
+	case errors.Is(err, errRequired):
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	case errors.Is(err, context.Canceled):
+		c.JSON(499, gin.H{"error": "request canceled"})
+	case errors.Is(err, ErrMaxQueue):
+		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
+	case errors.Is(err, os.ErrNotExist):
+		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
+	default:
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	}
}


@@ -545,9 +545,9 @@ func TestCreateDetectTemplate(t *testing.T) {
	}

	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
		filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
		filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
		filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
		filepath.Join(p, "blobs", "sha256-9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b7"),
		filepath.Join(p, "blobs", "sha256-b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce57cdac893db8139"),
	})
})


@@ -20,6 +20,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@@ -105,6 +106,24 @@ func Test_Routes(t *testing.T) {
			assert.Empty(t, len(modelList.Models))
		},
	},
{
Name: "openai empty list",
Method: http.MethodGet,
Path: "/v1/models",
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var modelList openai.ListCompletion
err = json.Unmarshal(body, &modelList)
require.NoError(t, err)
assert.Equal(t, "list", modelList.Object)
assert.Empty(t, modelList.Data)
},
},
	{
		Name:   "Tags Handler (yes tags)",
		Method: http.MethodGet,
@@ -128,6 +147,25 @@ func Test_Routes(t *testing.T) {
			assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
		},
	},
{
Name: "openai list models with tags",
Method: http.MethodGet,
Path: "/v1/models",
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var modelList openai.ListCompletion
err = json.Unmarshal(body, &modelList)
require.NoError(t, err)
assert.Len(t, modelList.Data, 1)
assert.Equal(t, "test-model:latest", modelList.Data[0].Id)
assert.Equal(t, "library", modelList.Data[0].OwnedBy)
},
},
	{
		Name:   "Create Model Handler",
		Method: http.MethodPost,
@@ -216,6 +254,24 @@ func Test_Routes(t *testing.T) {
			assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0")
		},
	},
{
Name: "openai retrieve model handler",
Method: http.MethodGet,
Path: "/v1/models/show-model",
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var retrieveResp api.RetrieveModelResponse
err = json.Unmarshal(body, &retrieveResp)
require.NoError(t, err)
assert.Equal(t, "show-model", retrieveResp.Id)
assert.Equal(t, "library", retrieveResp.OwnedBy)
},
},
	}

	t.Setenv("OLLAMA_MODELS", t.TempDir())


@@ -23,7 +23,8 @@ type LlmRequest struct {
	ctx             context.Context //nolint:containedctx
	model           *Model
	opts            api.Options
-	sessionDuration time.Duration
+	origNumCtx      int // Track the initial ctx request
+	sessionDuration *api.Duration
	successCh       chan *runnerRef
	errCh           chan error
	schedAttempts   uint
@@ -38,13 +39,23 @@ type Scheduler struct {
	loaded   map[string]*runnerRef
	loadedMu sync.Mutex

-	loadFn      func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
-	newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
+	loadFn      func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int)
+	newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
	getGpuFn     func() gpu.GpuInfoList
	getCpuFn     func() gpu.GpuInfoList
	reschedDelay time.Duration
}
// Default automatic value for number of models we allow per GPU
// Model will still need to fit in VRAM, but loading many small models
// on a large GPU can cause stalling
var defaultModelsPerGPU = 3

// Default automatic value for parallel setting
// Model will still need to fit in VRAM. If this setting won't fit
// we'll back off down to 1 to try to get it to fit
var defaultParallel = 4
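Taken together with the removal of opts.NumCtx *= envconfig.NumParallel from GetRunner, context sizing now happens at load time rather than request time. A minimal sketch of the sizing rule, under the names used in this file:

// effectiveNumCtx mirrors how the scheduler now sizes the KV cache:
// the caller's context length is multiplied by the parallelism the
// runner is actually loaded with, instead of a global up-front factor.
func effectiveNumCtx(origNumCtx, numParallel int) int {
	if origNumCtx < 4 {
		origNumCtx = 4 // GetRunner enforces this floor
	}
	if numParallel < 1 {
		numParallel = 1 // load() falls back to 1 the same way
	}
	return origNumCtx * numParallel
}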
var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded")

func InitScheduler(ctx context.Context) *Scheduler {
@@ -64,14 +75,11 @@ func InitScheduler(ctx context.Context) *Scheduler {
}

// context must be canceled to decrement ref count and release the runner
-func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
-	// allocate a large enough kv cache for all parallel requests
+func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
	if opts.NumCtx < 4 {
		opts.NumCtx = 4
	}

-	opts.NumCtx *= envconfig.NumParallel
-
	req := &LlmRequest{
		ctx:   c,
		model: model,
@@ -110,13 +118,28 @@ func (s *Scheduler) processPending(ctx context.Context) {
		case pending := <-s.pendingReqCh:
			// Block other requests until we get this pending request running
			pending.schedAttempts++
+			if pending.origNumCtx == 0 {
+				pending.origNumCtx = pending.opts.NumCtx
+			}

			if pending.ctx.Err() != nil {
				slog.Debug("pending request cancelled or timed out, skipping scheduling")
				continue
			}
+			numParallel := envconfig.NumParallel
+			// TODO (jmorganca): multimodal models don't support parallel yet
+			// see https://github.com/ollama/ollama/issues/4165
+			if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
+				numParallel = 1
+				slog.Warn("multimodal models don't support parallel requests yet")
+			}

			for {
+				cpus := s.getCpuFn()
+				var systemMem gpu.GpuInfo
+				if len(cpus) > 0 {
+					systemMem = cpus[0]
+				}
				var runnerToExpire *runnerRef
				s.loadedMu.Lock()
				runner := s.loaded[pending.model.ModelPath]
@@ -143,35 +166,94 @@ func (s *Scheduler) processPending(ctx context.Context) {
					gpus = s.getGpuFn()
				}
if envconfig.MaxRunners <= 0 {
// No user specified MaxRunners, so figure out what automatic setting to use
// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
// if any GPU has unreliable free memory reporting, 1x the number of GPUs
allReliable := true
for _, gpu := range gpus {
if gpu.UnreliableFreeMemory {
allReliable = false
break
}
}
if allReliable {
envconfig.MaxRunners = defaultModelsPerGPU * len(gpus)
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
} else {
slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency")
envconfig.MaxRunners = len(gpus)
}
}
				// Load model for fitting
-				ggml, err := llm.LoadModel(pending.model.ModelPath)
+				ggml, err := llm.LoadModel(pending.model.ModelPath, 0)
				if err != nil {
					pending.errCh <- err
					break
				}
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
maxSize := systemMem.FreeMemory
// Add available GPU memory to the total pool
// macOS hardware has unified memory so don't double count
if runtime.GOOS != "darwin" {
for _, gpu := range gpus {
if gpu.Library == "cpu" {
continue
}
if loadedCount == 0 {
// If no other models are loaded, set the limit based on what's available
maxSize += gpu.FreeMemory
} else {
// Other models could be unloaded, favor total memory for limit
maxSize += gpu.TotalMemory
}
}
}
// Block attempting to load a model larger than system memory + GPU memory
if estimate.TotalSize > maxSize {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
// Linux will crash if over-allocating memory - return an error to the user.
// TODO (jmorganca): add reasonable upper limits for darwin and windows as well
if runtime.GOOS == "linux" {
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
break
}
}
				// Evaluate if the model will fit in the available system memory, or if we should unload a model first
				if len(gpus) == 1 && gpus[0].Library == "cpu" {
+					// simplifying assumption of defaultParallel when in CPU mode
+					if numParallel <= 0 {
+						numParallel = defaultParallel
+					}
+					pending.opts.NumCtx = pending.origNumCtx * numParallel
+
					if loadedCount == 0 {
						slog.Debug("cpu mode with first model, loading")
-						s.loadFn(pending, ggml, gpus)
+						s.loadFn(pending, ggml, gpus, numParallel)
						break
					}
					runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus)
					if runnerToExpire == nil {
						slog.Debug("cpu mode with available system memory or first model, loading")
-						s.loadFn(pending, ggml, gpus)
+						s.loadFn(pending, ggml, gpus, numParallel)
						break
					}
					// else we need to expire a runner
				} else if loadedCount == 0 {
					// No models loaded. Load the model but prefer the best fit.
					slog.Debug("loading first model", "model", pending.model.ModelPath)
-					g := pickBestFitGPUs(pending, ggml, gpus)
+					g := pickBestFitGPUs(pending, ggml, gpus, &numParallel)
					if g != nil {
						gpus = g
					}
-					s.loadFn(pending, ggml, gpus)
+					s.loadFn(pending, ggml, gpus, numParallel)
					break
				}
@@ -186,10 +268,10 @@ func (s *Scheduler) processPending(ctx context.Context) {
					// Update free memory from currently loaded models
					s.updateFreeSpace(availGpus)
-					fitGpus := pickBestFitGPUs(pending, ggml, availGpus)
+					fitGpus := pickBestFitGPUs(pending, ggml, availGpus, &numParallel)
					if fitGpus != nil {
						slog.Debug("new model fits with existing models, loading")
-						s.loadFn(pending, ggml, fitGpus)
+						s.loadFn(pending, ggml, fitGpus, numParallel)
						break
					}
@@ -341,7 +423,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
		runner.expireTimer.Stop()
		runner.expireTimer = nil
	}
-	runner.sessionDuration = pending.sessionDuration
+	if pending.sessionDuration != nil {
+		runner.sessionDuration = pending.sessionDuration.Duration
+	}
	pending.successCh <- runner
	go func() {
		<-pending.ctx.Done()
@@ -350,8 +434,15 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
	}()
}
-func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) {
-	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
+func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+	if numParallel < 1 {
+		numParallel = 1
+	}
+	sessionDuration := envconfig.KeepAlive
+	if req.sessionDuration != nil {
+		sessionDuration = req.sessionDuration.Duration
+	}
+	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
	if err != nil {
		// some older models are not compatible with newer versions of llama.cpp
		// show a generalized compatibility error until there is a better way to
@@ -368,13 +459,14 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
		modelPath:       req.model.ModelPath,
		llama:           llama,
		Options:         &req.opts,
-		sessionDuration: req.sessionDuration,
+		sessionDuration: sessionDuration,
		gpus:            gpus,
		estimatedVRAM:   llama.EstimatedVRAM(),
		estimatedTotal:  llama.EstimatedTotal(),
		loading:         true,
		refCount:        1,
	}
+	runner.numParallel = numParallel
	runner.refMu.Lock()
	s.loadedMu.Lock()
@@ -483,8 +575,9 @@ type runnerRef struct {
	expireTimer *time.Timer
	expiresAt   time.Time

	model       *Model
	modelPath   string
+	numParallel int
	*api.Options
}
@@ -525,6 +618,9 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
		optsNew.NumGPU = -1
	}

+	// Normalize the NumCtx for parallelism
+	optsExisting.NumCtx = optsExisting.NumCtx / runner.numParallel
+
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
@@ -611,22 +707,38 @@ func (a ByDuration) Less(i, j int) bool {
// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
// If the model can not be fit fully within the available GPU(s) nil is returned
-func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.GpuInfoList {
+// If numParallel is <= 0, this will attempt to optimize parallelism based on available VRAM, and adjust
+// opts.NumCtx accordingly
+func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel *int) gpu.GpuInfoList {
	var estimatedVRAM uint64
+	var numParallelToTry []int
+	if *numParallel <= 0 {
+		// If no specific parallel setting was provided, try larger then smaller, always end with 1
+		numParallelToTry = append(numParallelToTry, defaultParallel, 1)
+	} else {
+		numParallelToTry = []int{*numParallel}
+	}

	for _, gl := range gpus.ByLibrary() {
		var ok bool
		sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)

		// TODO - potentially sort by performance capability, existing models loaded, etc.
-		// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
		// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
		sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))

		// First attempt to fit the model into a single GPU
-		if !envconfig.SchedSpread {
-			for _, g := range sgl {
-				if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
-					slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
-					return []gpu.GpuInfo{g}
-				}
-			}
-		}
+		for _, p := range numParallelToTry {
+			req.opts.NumCtx = req.origNumCtx * p
+			if !envconfig.SchedSpread {
+				for _, g := range sgl {
+					if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+						slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
+						*numParallel = p
+						return []gpu.GpuInfo{g}
+					}
+				}
+			}
+		}
@@ -636,9 +748,13 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.
		// - try subsets of GPUs instead of just falling back to 1 or all in a family

		// Now try all the GPUs
-		if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
-			slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
-			return sgl
-		}
+		for _, p := range numParallelToTry {
+			req.opts.NumCtx = req.origNumCtx * p
+			if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+				slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
+				*numParallel = p
+				return sgl
+			}
+		}
	}
	return nil
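Put simply, an explicit OLLAMA_NUM_PARALLEL is tried as-is, while the automatic path tries the default before falling back to 1. A small sketch of the candidate order (mirroring, not taken from, the code above):

// candidateParallelism reproduces pickBestFitGPUs' fallback order.
func candidateParallelism(requested int) []int {
	if requested <= 0 {
		return []int{defaultParallel, 1} // try larger, always end with 1
	}
	return []int{requested}
}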


@@ -44,14 +44,14 @@ func TestLoad(t *testing.T) {
		opts:            api.DefaultOptions(),
		successCh:       make(chan *runnerRef, 1),
		errCh:           make(chan error, 1),
-		sessionDuration: 2,
+		sessionDuration: &api.Duration{Duration: 2 * time.Second},
	}
	// Fail to load model first
-	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
		return nil, fmt.Errorf("something failed to load model blah")
	}
	gpus := gpu.GpuInfoList{}
-	s.load(req, ggml, gpus)
+	s.load(req, ggml, gpus, 0)
	require.Empty(t, req.successCh)
	require.Len(t, req.errCh, 1)
	s.loadedMu.Lock()
@@ -61,10 +61,10 @@ func TestLoad(t *testing.T) {
	require.Contains(t, err.Error(), "this model may be incompatible")

	server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
-	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
		return server, nil
	}
-	s.load(req, ggml, gpus)
+	s.load(req, ggml, gpus, 0)
	select {
	case err := <-req.errCh:
		require.NoError(t, err)
@@ -78,12 +78,12 @@ func TestLoad(t *testing.T) {
req.model.ModelPath = "dummy_model_path" req.model.ModelPath = "dummy_model_path"
server.waitResp = fmt.Errorf("wait failure") server.waitResp = fmt.Errorf("wait failure")
s.load(req, ggml, gpus) s.load(req, ggml, gpus, 0)
select { select {
case err := <-req.errCh: case err := <-req.errCh:
require.Contains(t, err.Error(), "wait failure") require.Contains(t, err.Error(), "wait failure")
case resp := <-req.successCh: case resp := <-req.successCh:
t.Errorf("unexpected success %v", resp) t.Fatalf("unexpected success %v", resp)
} }
s.loadedMu.Lock() s.loadedMu.Lock()
runner := s.loaded["dummy_model_path"] runner := s.loaded["dummy_model_path"]
@@ -102,7 +102,7 @@ type bundle struct {
	ggml *llm.GGML
}

-func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
	return scenario.srv, nil
}
@@ -128,21 +128,21 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
"tokenizer.ggml.scores": []float32{0}, "tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0}, "tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{ }, []llm.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
}) })
require.NoError(t, err) require.NoError(t, err)
fname := f.Name() fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname} model := &Model{Name: modelName, ModelPath: fname}
scenario.ggml, err = llm.LoadModel(model.ModelPath) scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
require.NoError(t, err) require.NoError(t, err)
scenario.req = &LlmRequest{ scenario.req = &LlmRequest{
ctx: scenario.ctx, ctx: scenario.ctx,
model: model, model: model,
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
sessionDuration: 5 * time.Millisecond, sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1), errCh: make(chan error, 1),
} }
@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
	// Same model, same request
	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = 5 * time.Millisecond
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
	scenario1b.req.model = scenario1a.req.model
	scenario1b.ggml = scenario1a.ggml
-	scenario1b.req.sessionDuration = 0
+	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}

	// simple reload of same model
	scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
	tmpModel := *scenario1a.req.model
	scenario2a.req.model = &tmpModel
	scenario2a.ggml = scenario1a.ggml
-	scenario2a.req.sessionDuration = 5 * time.Millisecond
+	scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}

	// Multiple loaded models
	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -199,8 +199,10 @@ func TestRequests(t *testing.T) {
		require.Equal(t, resp.llama, scenario1a.srv)
		require.Empty(t, s.pendingReqCh)
		require.Empty(t, scenario1a.req.errCh)
+	case err := <-scenario1a.req.errCh:
+		t.Fatal(err.Error())
	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
	}

	// Same runner as first request due to not needing a reload
@@ -212,8 +214,10 @@ func TestRequests(t *testing.T) {
		require.Equal(t, resp.llama, scenario1a.srv)
		require.Empty(t, s.pendingReqCh)
		require.Empty(t, scenario1b.req.errCh)
+	case err := <-scenario1b.req.errCh:
+		t.Fatal(err.Error())
	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
	}

	// Trigger a reload
@@ -230,8 +234,10 @@ func TestRequests(t *testing.T) {
		require.Equal(t, resp.llama, scenario2a.srv)
		require.Empty(t, s.pendingReqCh)
		require.Empty(t, scenario2a.req.errCh)
+	case err := <-scenario2a.req.errCh:
+		t.Fatal(err.Error())
	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
	}

	envconfig.MaxRunners = 1
@@ -246,8 +252,10 @@ func TestRequests(t *testing.T) {
		require.Equal(t, resp.llama, scenario3a.srv)
		require.Empty(t, s.pendingReqCh)
		require.Empty(t, scenario3a.req.errCh)
+	case err := <-scenario3a.req.errCh:
+		t.Fatal(err.Error())
	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
	}
	s.loadedMu.Lock()
	require.Len(t, s.loaded, 1)
@@ -262,8 +270,10 @@ func TestRequests(t *testing.T) {
		require.Equal(t, resp.llama, scenario3b.srv)
		require.Empty(t, s.pendingReqCh)
		require.Empty(t, scenario3b.req.errCh)
+	case err := <-scenario3b.req.errCh:
+		t.Fatal(err.Error())
	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
	}
	s.loadedMu.Lock()
	require.Len(t, s.loaded, 2)
@@ -278,8 +288,10 @@ func TestRequests(t *testing.T) {
		require.Equal(t, resp.llama, scenario3c.srv)
		require.Empty(t, s.pendingReqCh)
		require.Empty(t, scenario3c.req.errCh)
+	case err := <-scenario3c.req.errCh:
+		t.Fatal(err.Error())
	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
	}
	s.loadedMu.Lock()
	require.Len(t, s.loaded, 3)
@@ -306,7 +318,7 @@ func TestRequests(t *testing.T) {
		require.Empty(t, s.pendingReqCh)
		require.Empty(t, scenario3d.req.errCh)
	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
	}
	s.loadedMu.Lock()
	require.Len(t, s.loaded, 2)
@@ -318,11 +330,11 @@ func TestGetRunner(t *testing.T) {
	defer done()
	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
-	scenario1a.req.sessionDuration = 0
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
-	scenario1b.req.sessionDuration = 0
+	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
	scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
-	scenario1c.req.sessionDuration = 0
+	scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
	envconfig.MaxQueuedRequests = 1
	s := InitScheduler(ctx)
	s.getGpuFn = func() gpu.GpuInfoList {
@@ -349,7 +361,7 @@ func TestGetRunner(t *testing.T) {
		require.Empty(t, s.pendingReqCh)
		require.Empty(t, errCh1a)
	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
	}
	scenario1a.ctxDone()
	s.loadedMu.Lock()
@@ -400,9 +412,9 @@ func TestPrematureExpired(t *testing.T) {
slog.Info("sending premature expired event now") slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Fatal("timeout")
} }
time.Sleep(scenario1a.req.sessionDuration) time.Sleep(scenario1a.req.sessionDuration.Duration)
scenario1a.ctxDone() scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1) require.LessOrEqual(t, len(s.finishedReqCh), 1)
@@ -423,11 +435,11 @@ func TestUseLoadedRunner(t *testing.T) {
		ctx:             ctx,
		opts:            api.DefaultOptions(),
		successCh:       make(chan *runnerRef, 1),
-		sessionDuration: 2,
+		sessionDuration: &api.Duration{Duration: 2},
	}
	finished := make(chan *LlmRequest)
	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
-	r1 := &runnerRef{llama: llm1, sessionDuration: 1}
+	r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1}
	req.useLoadedRunner(r1, finished)
	require.Equal(t, uint(1), r1.refCount)
	require.Equal(t, time.Duration(2), r1.sessionDuration)
@@ -435,7 +447,7 @@ func TestUseLoadedRunner(t *testing.T) {
	case success := <-req.successCh:
		require.Equal(t, r1, success)
	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
	}
	done()
	fin := <-finished
@@ -461,8 +473,8 @@ func TestUpdateFreeSpace(t *testing.T) {
	gpus[1].FreeMemory = 1900
	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 50, "2": 50}}
	llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 125, "2": 75}}
-	r1 := &runnerRef{llama: llm1, gpus: gpus}
-	r2 := &runnerRef{llama: llm2, gpus: gpus}
+	r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1}
+	r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1}

	s := InitScheduler(ctx)
	s.loadedMu.Lock()
@@ -513,8 +525,8 @@ func TestFindRunnerToUnload(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer done()

-	r1 := &runnerRef{refCount: 1, sessionDuration: 1}
-	r2 := &runnerRef{sessionDuration: 2}
+	r1 := &runnerRef{refCount: 1, sessionDuration: 1, numParallel: 1}
+	r2 := &runnerRef{sessionDuration: 2, numParallel: 1}

	s := InitScheduler(ctx)
	s.loadedMu.Lock()
@@ -536,9 +548,13 @@ func TestNeedsReload(t *testing.T) {
	llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
	do := api.DefaultOptions()
	runner := &runnerRef{
-		model:   &Model{AdapterPaths: []string{"adapter1"}, ProjectorPaths: []string{"projector1"}},
-		Options: &do,
-		llama:   llm,
+		model: &Model{
+			AdapterPaths:   []string{"adapter1"},
+			ProjectorPaths: []string{"projector1"},
+		},
+		Options:     &do,
+		llama:       llm,
+		numParallel: 1,
	}
	req := &LlmRequest{
		model: &Model{
@@ -581,8 +597,8 @@ func TestUnloadAllRunners(t *testing.T) {
	s := InitScheduler(ctx)
	s.unloadAllRunners()

-	r1 := &runnerRef{llama: llm1}
-	r2 := &runnerRef{llama: llm2}
+	r1 := &runnerRef{llama: llm1, numParallel: 1}
+	r2 := &runnerRef{llama: llm2, numParallel: 1}

	s.loadedMu.Lock()
	s.loaded["a"] = r1
@@ -596,14 +612,32 @@ func TestUnloadAllRunners(t *testing.T) {
func TestUnload(t *testing.T) {
	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
-	r1 := &runnerRef{llama: llm1}
-	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}}
+	r1 := &runnerRef{llama: llm1, numParallel: 1}
+	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1}
	r1.unload()
	require.True(t, llm1.closeCalled)
	r2.unload()
	require.Nil(t, r2.model)
}
func TestAlreadyCanceled(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
dctx, done2 := context.WithCancel(ctx)
done2()
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
s := InitScheduler(ctx)
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
time.Sleep(5 * time.Millisecond)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1a.req.errCh)
require.Empty(t, scenario1a.req.successCh)
}
type mockLlm struct {
	pingResp  error
	waitResp  error

template/alfred.gotmpl (new file)

@@ -0,0 +1,8 @@
{{- if .Messages }}
{{- if .System }}<start_system>{{ .System }}<end_message>
{{- end }}
{{- range .Messages }}<start_{{ .Role }}>{{ .Content }}<end_message>
{{- end }}<start_assistant>
{{- else }}
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
{{- end }}

template/alpaca.gotmpl (new file)

@@ -0,0 +1,19 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{- end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- else if eq .Role "assistant" }}### Response:
{{- end }}
{{ .Content }}
{{ end }}### Response:
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
{{ .Prompt }}
{{ end }}### Response:
{{ .Response }}
{{- end }}

template/chatml.gotmpl (new file)

@@ -0,0 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end }}
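To make the two branches concrete: with the .Messages path, a system prompt followed by one user turn renders roughly as below (message text is illustrative):

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello!<|im_end|>
<|im_start|>assistant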

template/chatqa.gotmpl (new file)

@@ -0,0 +1,17 @@
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{- else if eq .Role "assistant" }}Assistant:
{{- end }} {{ .Content }}
{{ end }}Assistant:
{{- else }}
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
{{- end }}


@@ -0,0 +1,19 @@
{{- if .Messages }}
{{- if .System }}Source: system
{{ .System }} <step> {{ end }}
{{- range .Messages }}Source: {{ .Role }}
{{ .Content }} <step> {{ end }}Source: assistant
Destination: user
{{ else }}
{{ if .System }} Source: system
{{ .System }} <step>{{ end }} Source: user
{{ .Prompt }} <step> Source: assistant
Destination: user
{{ .Response }}<step>
{{- end }}


@@ -0,0 +1,13 @@
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{ else if eq .Role "assistant" }}Falcon:
{{ end }}{{ .Content }}
{{ end }}Falcon:
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}
{{- end }}


@@ -0,0 +1,16 @@
{{- if .Messages }}
{{- range $index, $_ := .Messages }}<start_of_turn>
{{- if eq .Role "user" }}user
{{- if and $.System (eq $index 0) }}
{{ $.System }}
{{- end }}
{{- else if eq .Role "assistant" }}model
{{- end }}
{{ .Content }}<end_of_turn>
{{ end }}<start_of_turn>model
{{ else }}
<start_of_turn>user
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>
{{- end }}


@@ -0,0 +1,23 @@
{{- if .Messages }}
{{- if .System }}System:
{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}Question:
{{- else if eq .Role "assistant" }}Answer:
{{- end }}
{{ .Content }}
{{ end }}Answer:
{{ else }}
{{ if .System }}
System:
{{ .System }}
{{ end }}{{ if .Prompt }}Question:
{{ .Prompt }}
{{ end }}Answer:
{{ .Response }}
{{- end }}


@@ -0,0 +1,16 @@
{{- if .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if eq $index 0 }}<<SYS>>
{{- if $.System }}
{{ $.System }}
{{ end }}<</SYS>>
{{ end }}{{ .Content }}
{{- else }} [/INST] {{ .Content }}</s><s>
{{- end }}
{{- end }} [/INST]
{{- else }}
[INST] <<SYS>>{{ .System }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }}
{{- end }}


@@ -0,0 +1,19 @@
{{- if .Messages }}
{{- if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ .Content }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>
{{ else }}
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>
{{- end }}

template/magicoder.gotmpl (new file)

@@ -0,0 +1,20 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}@@ Instruction
{{- else if eq .Role "assistant" }}@@ Response
{{- end }}
{{ .Content }}
{{ end }}@@ Response
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}@@ Instruction
{{ .Prompt }}
{{ end }}@@ Response
{{ .Response }}
{{- end }}


@@ -0,0 +1,9 @@
{{- if .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}
{{ end }}{{ .Content }}
{{- else if eq .Role "assistant" }}[/INST] {{ .Content }}</s>
{{- end }}
{{- end }}[/INST]
{{- else }}[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}
{{- end }}

template/openchat.gotmpl (new file)

@@ -0,0 +1,11 @@
{{- if .Messages }}
{{- if .System }}GPT Correct System: {{ .System }}<|end_of_turn|>
{{- end }}
{{- range .Messages }}GPT Correct
{{- if eq .Role "user" }} User:
{{- else if eq .Role "assistant" }} Assistant:
{{- end }} {{ .Content }}<|end_of_turn|>
{{- end }}GPT Correct Assistant:
{{- else }}
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
{{- end }}

template/phi-3.gotmpl (new file)

@@ -0,0 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}
{{- range .Messages }}<|{{ .Role }}|>
{{ .Content }}<|end|>
{{ end }}<|assistant|>
{{ else }}
{{ if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}<|end|>
{{ end }}<|assistant|>
{{ .Response }}<|end|>
{{- end }}


@@ -0,0 +1,22 @@
{{- if .Messages }}
{{- if .System }}### System:
{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### User:
{{ .Content }}
{{ else if eq .Role "assistant" }}### Assistant:
{{ .Content }}</s>
{{ end }}
{{ end }}### Assistant:
{{ else }}
{{ if .System }}### System:
{{ .System }}
{{ end }}{{ if .Prompt }}### User:
{{ .Prompt }}
{{ end }}### Assistant:
{{ .Response }}
{{- end }}


@@ -0,0 +1,24 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction
{{ .Content }}
{{ else if eq .Role "assistant" }}### Response
{{ .Content }}<|endoftext|>
{{ end }}
{{- end }}### Response
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction
{{ .Prompt }}
{{ end }}### Response
{{ .Response }}<|endoftext|>
{{- end }}

template/template.go (new file)

@@ -0,0 +1,288 @@
package template

import (
    "bytes"
    "embed"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "math"
    "slices"
    "strings"
    "sync"
    "text/template"
    "text/template/parse"

    "github.com/agnivade/levenshtein"
    "github.com/ollama/ollama/api"
    "golang.org/x/exp/maps"
)

//go:embed index.json
var indexBytes []byte

//go:embed *.gotmpl
var templatesFS embed.FS

var templatesOnce = sync.OnceValues(func() ([]*named, error) {
    var templates []*named
    if err := json.Unmarshal(indexBytes, &templates); err != nil {
        return nil, err
    }

    for _, t := range templates {
        bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
        if err != nil {
            return nil, err
        }

        // normalize line endings
        t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
    }

    return templates, nil
})
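
// Aside (assumed, not part of the original file): index.json is expected to be
// a JSON array pairing each built-in template name with the reference chat
// template text it was derived from, roughly:
//
//     [{"name": "chatml", "template": "{% for message in messages %}..."}]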

type named struct {
    Name     string `json:"name"`
    Template string `json:"template"`
    Bytes    []byte
}

func (t named) Reader() io.Reader {
    return bytes.NewReader(t.Bytes)
}

func Named(s string) (*named, error) {
    templates, err := templatesOnce()
    if err != nil {
        return nil, err
    }

    var template *named
    score := math.MaxInt
    for _, t := range templates {
        if s := levenshtein.ComputeDistance(s, t.Template); s < score {
            score = s
            template = t
        }
    }

    if score < 100 {
        return template, nil
    }

    return nil, errors.New("no matching template found")
}

var DefaultTemplate, _ = Parse("{{ .Prompt }}")

type Template struct {
    *template.Template
    raw string
}

// response is a template node that can be added to templates that don't already have one
var response = parse.ActionNode{
    NodeType: parse.NodeAction,
    Pipe: &parse.PipeNode{
        NodeType: parse.NodePipe,
        Cmds: []*parse.CommandNode{
            {
                NodeType: parse.NodeCommand,
                Args: []parse.Node{
                    &parse.FieldNode{
                        NodeType: parse.NodeField,
                        Ident:    []string{"Response"},
                    },
                },
            },
        },
    },
}

func Parse(s string) (*Template, error) {
    tmpl := template.New("").Option("missingkey=zero")

    tmpl, err := tmpl.Parse(s)
    if err != nil {
        return nil, err
    }

    t := Template{Template: tmpl, raw: s}
    if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
        // touch up the template and append {{ .Response }}
        tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
    }

    return &t, nil
}
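
// Aside (assumed, not part of the original file): because Parse appends the
// {{ .Response }} node when a template references neither .Messages nor
// .Response, a bare prompt template still reports both variables:
//
//     t, _ := Parse("{{ .Prompt }}")
//     fmt.Println(t.Vars()) // [prompt response]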

func (t *Template) String() string {
    return t.raw
}

func (t *Template) Vars() []string {
    var vars []string
    for _, tt := range t.Templates() {
        for _, n := range tt.Root.Nodes {
            vars = append(vars, parseNode(n)...)
        }
    }

    set := make(map[string]struct{})
    for _, n := range vars {
        set[strings.ToLower(n)] = struct{}{}
    }

    vars = maps.Keys(set)
    slices.Sort(vars)
    return vars
}

type Values struct {
    Messages []api.Message
}

func (t *Template) Execute(w io.Writer, v Values) error {
    system, collated := collate(v.Messages)
    if slices.Contains(t.Vars(), "messages") {
        return t.Template.Execute(w, map[string]any{
            "System":   system,
            "Messages": collated,
        })
    }

    var b bytes.Buffer
    var prompt, response string
    for i, m := range collated {
        if m.Role == "user" {
            prompt = m.Content
        } else {
            response = m.Content
        }

        if i != len(collated)-1 && prompt != "" && response != "" {
            if err := t.Template.Execute(&b, map[string]any{
                "System":   "",
                "Prompt":   prompt,
                "Response": response,
            }); err != nil {
                return err
            }

            prompt = ""
            response = ""
        }
    }

    var cut bool
    tree := t.Template.Copy()
    // for the last message, cut everything after "{{ .Response }}"
    tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
        if slices.Contains(parseNode(n), "Response") {
            cut = true
        }

        return cut
    })

    if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
        "System": system,
        "Prompt": prompt,
    }); err != nil {
        return err
    }

    _, err := io.Copy(w, &b)
    return err
}

type messages []*api.Message

// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed
func collate(msgs []api.Message) (system string, collated messages) {
    var n int
    for i := range msgs {
        msg := msgs[i]
        if msg.Role == "system" {
            if system != "" {
                system += "\n\n"
            }

            system += msg.Content
            continue
        }

        for range msg.Images {
            imageTag := fmt.Sprintf("[img-%d]", n)
            if !strings.Contains(msg.Content, "[img]") {
                msg.Content = strings.TrimSpace("[img] " + msg.Content)
            }

            msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
            n++
        }

        if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
            collated[len(collated)-1].Content += "\n\n" + msg.Content
        } else {
            collated = append(collated, &msg)
        }
    }

    return
}
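
// Aside (assumed, not part of the original file): consecutive same-role
// messages merge and images are tagged in order, so for example:
//
//     _, msgs := collate([]api.Message{
//         {Role: "user", Content: "Look:"},
//         {Role: "user", Content: "What is it?", Images: []api.ImageData{data}},
//     })
//     // len(msgs) == 1
//     // msgs[0].Content == "Look:\n\n[img-0] What is it?"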

func parseNode(n parse.Node) []string {
    switch n := n.(type) {
    case *parse.ActionNode:
        return parseNode(n.Pipe)
    case *parse.IfNode:
        names := parseNode(n.Pipe)
        names = append(names, parseNode(n.List)...)
        if n.ElseList != nil {
            names = append(names, parseNode(n.ElseList)...)
        }

        return names
    case *parse.RangeNode:
        names := parseNode(n.Pipe)
        names = append(names, parseNode(n.List)...)
        if n.ElseList != nil {
            names = append(names, parseNode(n.ElseList)...)
        }

        return names
    case *parse.WithNode:
        names := parseNode(n.Pipe)
        names = append(names, parseNode(n.List)...)
        if n.ElseList != nil {
            names = append(names, parseNode(n.ElseList)...)
        }

        return names
    case *parse.PipeNode:
        var names []string
        for _, c := range n.Cmds {
            for _, a := range c.Args {
                names = append(names, parseNode(a)...)
            }
        }

        return names
    case *parse.ListNode:
        var names []string
        for _, n := range n.Nodes {
            names = append(names, parseNode(n)...)
        }

        return names
    case *parse.FieldNode:
        return n.Ident
    case *parse.TemplateNode:
        return parseNode(n.Pipe)
    }

    return nil
}
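
A minimal usage sketch of the package above (assumed, not part of the diff; the import paths follow the repository layout):

    package main

    import (
        "os"

        "github.com/ollama/ollama/api"
        "github.com/ollama/ollama/template"
    )

    func main() {
        // No {{ range .Messages }} here, so Execute takes the legacy path:
        // messages are collated into Prompt/Response pairs and {{ .Response }}
        // is appended automatically by Parse.
        tmpl, err := template.Parse("{{ if .System }}{{ .System }}\n{{ end }}{{ .Prompt }} ")
        if err != nil {
            panic(err)
        }

        if err := tmpl.Execute(os.Stdout, template.Values{
            Messages: []api.Message{
                {Role: "system", Content: "You are concise."},
                {Role: "user", Content: "Hello!"},
            },
        }); err != nil {
            panic(err)
        }
    }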

310
template/template_test.go Normal file
View File

@@ -0,0 +1,310 @@
package template

import (
    "bufio"
    "bytes"
    "encoding/json"
    "io"
    "os"
    "path/filepath"
    "slices"
    "strings"
    "testing"

    "github.com/google/go-cmp/cmp"
    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/llm"
)

func TestNamed(t *testing.T) {
    f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
    if err != nil {
        t.Fatal(err)
    }
    defer f.Close()

    scanner := bufio.NewScanner(f)
    for scanner.Scan() {
        var ss map[string]string
        if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
            t.Fatal(err)
        }

        for k, v := range ss {
            t.Run(k, func(t *testing.T) {
                kv := llm.KV{"tokenizer.chat_template": v}
                s := kv.ChatTemplate()
                r, err := Named(s)
                if err != nil {
                    t.Fatal(err)
                }

                if r.Name != k {
                    t.Errorf("expected %q, got %q", k, r.Name)
                }

                var b bytes.Buffer
                if _, err := io.Copy(&b, r.Reader()); err != nil {
                    t.Fatal(err)
                }

                tmpl, err := Parse(b.String())
                if err != nil {
                    t.Fatal(err)
                }

                if tmpl.Tree.Root.String() == "" {
                    t.Errorf("empty %s template", k)
                }
            })
        }
    }
}

func TestTemplate(t *testing.T) {
    cases := make(map[string][]api.Message)
    for _, mm := range [][]api.Message{
        {
            {Role: "user", Content: "Hello, how are you?"},
        },
        {
            {Role: "user", Content: "Hello, how are you?"},
            {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
            {Role: "user", Content: "I'd like to show off how chat templating works!"},
        },
        {
            {Role: "system", Content: "You are a helpful assistant."},
            {Role: "user", Content: "Hello, how are you?"},
            {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
            {Role: "user", Content: "I'd like to show off how chat templating works!"},
        },
    } {
        var roles []string
        for _, m := range mm {
            roles = append(roles, m.Role)
        }

        cases[strings.Join(roles, "-")] = mm
    }

    matches, err := filepath.Glob("*.gotmpl")
    if err != nil {
        t.Fatal(err)
    }

    for _, match := range matches {
        t.Run(match, func(t *testing.T) {
            bts, err := os.ReadFile(match)
            if err != nil {
                t.Fatal(err)
            }

            tmpl, err := Parse(string(bts))
            if err != nil {
                t.Fatal(err)
            }

            for n, tt := range cases {
                t.Run(n, func(t *testing.T) {
                    var actual bytes.Buffer
                    if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
                        t.Fatal(err)
                    }

                    expect, err := os.ReadFile(filepath.Join("testdata", match, n))
                    if err != nil {
                        t.Fatal(err)
                    }

                    if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
                        t.Errorf("mismatch (-got +want):\n%s", diff)
                    }
                })
            }
        })
    }
}

func TestParse(t *testing.T) {
    cases := []struct {
        template string
        vars     []string
    }{
        {"{{ .Prompt }}", []string{"prompt", "response"}},
        {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
        {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
        {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
        {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
        {"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
    }

    for _, tt := range cases {
        t.Run("", func(t *testing.T) {
            tmpl, err := Parse(tt.template)
            if err != nil {
                t.Fatal(err)
            }

            vars := tmpl.Vars()
            if !slices.Equal(tt.vars, vars) {
                t.Errorf("expected %v, got %v", tt.vars, vars)
            }
        })
    }
}

func TestExecuteWithMessages(t *testing.T) {
    type template struct {
        name     string
        template string
    }

    cases := []struct {
        name      string
        templates []template
        values    Values
        expected  string
    }{
        {
            "mistral",
            []template{
                {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
                {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
                {"messages", `{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
            },
            Values{
                Messages: []api.Message{
                    {Role: "user", Content: "Hello friend!"},
                    {Role: "assistant", Content: "Hello human!"},
                    {Role: "user", Content: "What is your name?"},
                },
            },
            `[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
        },
        {
            "mistral system",
            []template{
                {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
                {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
                {"messages", `
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
            },
            Values{
                Messages: []api.Message{
                    {Role: "system", Content: "You are a helpful assistant!"},
                    {Role: "user", Content: "Hello friend!"},
                    {Role: "assistant", Content: "Hello human!"},
                    {Role: "user", Content: "What is your name?"},
                },
            },
            `[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!

What is your name?[/INST] `,
        },
        {
            "chatml",
            []template{
                // this does not have a "no response" test because it's impossible to render the same output
                {"response", `{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
`},
                {"messages", `
{{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>assistant
`},
            },
            Values{
                Messages: []api.Message{
                    {Role: "system", Content: "You are a helpful assistant!"},
                    {Role: "user", Content: "Hello friend!"},
                    {Role: "assistant", Content: "Hello human!"},
                    {Role: "user", Content: "What is your name?"},
                },
            },
            `<|im_start|>user
Hello friend!<|im_end|>
<|im_start|>assistant
Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
What is your name?<|im_end|>
<|im_start|>assistant
`,
        },
        {
            "moondream",
            []template{
                // this does not have a "no response" test because it's impossible to render the same output
                {"response", `{{ if .Prompt }}Question: {{ .Prompt }}

{{ end }}Answer: {{ .Response }}

`},
                {"messages", `
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
{{- end }}
{{- end }}Answer: `},
            },
            Values{
                Messages: []api.Message{
                    {Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
                    {Role: "assistant", Content: "It's a hot dog."},
                    {Role: "user", Content: "What's in _this_ image?"},
                    {Role: "user", Images: []api.ImageData{[]byte("")}},
                    {Role: "user", Content: "Is it a hot dog?"},
                },
            },
            `Question: [img-0] What's in this image?

Answer: It's a hot dog.

Question: What's in _this_ image?

[img-1]

Is it a hot dog?

Answer: `,
        },
    }

    for _, tt := range cases {
        t.Run(tt.name, func(t *testing.T) {
            for _, ttt := range tt.templates {
                t.Run(ttt.name, func(t *testing.T) {
                    tmpl, err := Parse(ttt.template)
                    if err != nil {
                        t.Fatal(err)
                    }

                    var b bytes.Buffer
                    if err := tmpl.Execute(&b, tt.values); err != nil {
                        t.Fatal(err)
                    }

                    if b.String() != tt.expected {
                        t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
                    }
                })
            }
        })
    }
}
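
These tests are table-driven against the testdata files that follow; a single template's cases can be run in isolation with Go's subtest filtering, for example:

    go test ./template -run 'TestExecuteWithMessages/chatml' -v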

View File

@@ -0,0 +1 @@
<start_system>You are a helpful assistant.<end_message><start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

1
template/testdata/alfred.gotmpl/user vendored Normal file
View File

@@ -0,0 +1 @@
<start_user>Hello, how are you?<end_message><start_assistant>

View File

@@ -0,0 +1 @@
<start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

View File

@@ -0,0 +1,10 @@
You are a helpful assistant.### Instruction:
Hello, how are you?

### Response:
I'm doing great. How can I help you today?

### Instruction:
I'd like to show off how chat templating works!

### Response:

4
template/testdata/alpaca.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,4 @@
### Instruction:
Hello, how are you?

### Response:

View File

@@ -0,0 +1,10 @@
### Instruction:
Hello, how are you?

### Response:
I'm doing great. How can I help you today?

### Instruction:
I'd like to show off how chat templating works!

### Response:

View File

@@ -0,0 +1,9 @@
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>
<|im_start|>user
I'd like to show off how chat templating works!<|im_end|>
<|im_start|>assistant

3
template/testdata/chatml.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,3 @@
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant

Some files were not shown because too many files have changed in this diff.