Compare commits

mattw/quan...v0.1.23 (248 commits)
Commits (abbreviated SHAs, in page order):

```
09a6f76f4c e135167484 38296ab352 f43dea68d1 e1f50377f4 7913104527 bfbf2f7cf7 fe3cbd014f
3d6f48507a f3761405c8 e49dc9f3d8 d125510b4b 1ca386aa9e fb56988014 d046bee790 f11bf0740b
8450bf66e6 b4e11be8ef a896079705 583950c828 8ac08a0eec 60f47be64c 6e56077ada 98ae9467bb
b7a24af083 c8b1f2369e 72b12c3be7 0632dff3f8 509e2dec8a 78a48de804 e7dbb00331 2e06ed01d5
4072b5879b 15562e887d f2245c7c77 e4b9b72f2a 311f8e0c3f f07f8b7a9e e02ecfb6c8 c8059b4dcf
59d87127f5 b5cf31b460 cc4915e262 667a2ba18a e054ebe059 9d3dcfd0ec 6e0ea5ecc8 a47d8b2557
30c43c285c 23a7ea593b 75c44aa319 9d7b5d6c91 5d9c4a5f5a 197e420a97 a34e1ad3cf 2ae0556292
5be9bdd444 b706794905 a8c5413d06 5580de4571 946431d5b0 0610126049 3ebd6a83fc a64570dcae
7c40a67841 e64b5b07a2 9e1e295cdc a643823f86 8e5d359a03 a170888dd4 cd22855ef8 013fd07139
f63dc2db5c eaa5a396d9 8ed22f5d72 987c16b2f7 950f636d64 4458efb73a ceea599494 3005ec74b3
0759d8996e 0f5b843319 ffaf52e1e9 940b10b036 3bc28736cd 93a756266c a0a829bf7a 730dcfcc7a
27a2d5af54 5f81a33f43 6225fde046 069184562b 5576bb2348 2738837786 ec3764538d df54c723ae
fa8c990e58 da72235ebf 89c4aee29e a447a083f2 f32ea81b21 681a914990 4c54f0ddeb c08dfaa23d
3b76e736ae 552db98bf1 fdcdfef620 6a042438af dc88cc3981 62976087c6 344342abdf eb76f3e379
d017e3d0a6 aac9ab4db7 1f5b7ff976 e299831e2c 745b5934fa a38d88d828 abec7f06e5 e5da190bac
ecbfc0182f fedd705aea 82ee019bfc ad9dbc2a04 fccdf4c635 d450fb1d1e df40b11d03 9cd20b0ec8
b992bf65fc 1b249748ab cbe2adc78a d5a7353357 96cfb62641 7d00b5d110 795674dd90 e282bdccdd
d9bfb2f08f 598d6d5572 a897e833b8 eef50accb4 05d53de7a1 8795447dad b3035112a1 95ad9a9fc8
3ca5f69ce8 cfa6337960 f4bf1d514f 557110d0ba 2ecb247276 288ef8ff95 4cf17990f7 27331ae3a8
b6c0ef1e70 356d178f6e eaed6f8c45 6a5bfc2ed6 cf29bd2d72 905862e17b 565f8a3c44 5121b7ac9c
a70262c6b2 40a0a90a88 cbe20c4375 5ffbbea1d7 3773fb6465 7427fa1387 f84537e0e0 d2be6387c9
d7af35d3d0 defc1dbd6e de2fbdec99 f5faf79aa1 f4f939de28 39928a42e8 d88c527be3 3bc8b9832b
ab6be852c7 052b33b81b 8da7bef05f b24e8d17b2 f83881390f ac70ab6761 3c49c3ab0d 9754ae4c89
224fbf2795 2c6e8f5248 34344d801c e868c8a5c7 c336693f07 e89dc1d54b 1961a81f03 8a8c7e7f8d
6df83e6daa f921e2696e 4a33cede20 f95d2f25f3 2b9892a808 2bb2bdd5d4 acfc376efd 997253143f
62023177f6 6164f378f2 f387e9631b 6566387ae3 37708931fb f6cb0a553c 2680078c13 f1b7e5f560
cb534e6ac2 58ce2d8273 18ddf6d57d 61e6502449 08f1e18965 7e8f7c8358 3f3eb19a3b 059ae4585e
6347f501ca 5feec959ad dbdd50b283 d74ce6bd4f 57942b4676 e0d05b0f1e 2d9dd14f27 1caa56128f
0101e76dbe 2ef9352b94 5580ae2472 3a9f447141 9c2941e61b 238ac5e765 4f4980b66b 22e93efa41
2909dce894 df32537312 3367b5f3df 46edbbc518 d2ff18cd6b df086d3c8c 8baaaa39c0 f9961c70ae
cd8fad3398 9983fa5f4e dfda91c2ee fac9060da5 a554616f8e 77d96da94b 0d6e3565ae e201efa14b
```
Ignore-list file (name not captured in this render):

```diff
@@ -2,7 +2,7 @@
 ollama
 app
 dist
-llm/llama.cpp/gguf
+llm/llama.cpp
 .env
 .cache
 test_data
```
`.github/workflows/test.yaml` (new file, 162 lines):

```yaml
name: test

on:
  pull_request:

jobs:
  generate:
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        arch: [amd64, arm64]
        exclude:
          - os: ubuntu-latest
            arch: arm64
          - os: windows-latest
            arch: arm64
    runs-on: ${{ matrix.os }}
    env:
      GOARCH: ${{ matrix.arch }}
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
          go-version: '1.21'
          cache: true
      - run: go get ./...
      - run: go generate -x ./...
      - uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
          path: llm/llama.cpp/build/**/lib/*
  generate-cuda:
    strategy:
      matrix:
        cuda-version:
          - '11.8.0'
    runs-on: ubuntu-latest
    container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04
    steps:
      - run: |
          apt-get update && apt-get install -y git build-essential curl
          curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
            | tar -zx -C /usr --strip-components 1
        env:
          DEBIAN_FRONTEND: noninteractive
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v4
        with:
          go-version: '1.21'
          cache: true
      - run: go get ./...
      - run: |
          git config --global --add safe.directory /__w/ollama/ollama
          go generate -x ./...
        env:
          OLLAMA_SKIP_CPU_GENERATE: '1'
      - uses: actions/upload-artifact@v4
        with:
          name: cuda-${{ matrix.cuda-version }}-libraries
          path: llm/llama.cpp/build/**/lib/*
  generate-rocm:
    strategy:
      matrix:
        rocm-version:
          - '5.7.1'
          - '6.0'
    runs-on: ubuntu-latest
    container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
    steps:
      - run: |
          apt-get update && apt-get install -y git build-essential curl rocm-libs
          curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
            | tar -zx -C /usr --strip-components 1
        env:
          DEBIAN_FRONTEND: noninteractive
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v4
        with:
          go-version: '1.21'
          cache: true
      - run: go get ./...
      - run: |
          git config --global --add safe.directory /__w/ollama/ollama
          go generate -x ./...
        env:
          OLLAMA_SKIP_CPU_GENERATE: '1'
      - uses: actions/upload-artifact@v4
        with:
          name: rocm-${{ matrix.rocm-version }}-libraries
          path: llm/llama.cpp/build/**/lib/*
  lint:
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        arch: [amd64, arm64]
        exclude:
          - os: ubuntu-latest
            arch: arm64
          - os: windows-latest
            arch: arm64
          - os: macos-latest
            arch: amd64
    runs-on: ${{ matrix.os }}
    env:
      GOARCH: ${{ matrix.arch }}
      CGO_ENABLED: "1"
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - uses: actions/setup-go@v5
        with:
          go-version: '1.21'
          cache: false
      - run: |
          mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/
          touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so
        if: ${{ startsWith(matrix.os, 'ubuntu-') }}
      - run: |
          mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/
          touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib
          touch llm/llama.cpp/ggml-metal.metal
        if: ${{ startsWith(matrix.os, 'macos-') }}
      - run: |
          mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/
          touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll
        if: ${{ startsWith(matrix.os, 'windows-') }}
      - uses: golangci/golangci-lint-action@v3
  test:
    needs: generate
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        arch: [amd64]
        exclude:
          - os: ubuntu-latest
            arch: arm64
          - os: windows-latest
            arch: arm64
    runs-on: ${{ matrix.os }}
    env:
      GOARCH: ${{ matrix.arch }}
      CGO_ENABLED: "1"
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - uses: actions/setup-go@v5
        with:
          go-version: '1.21'
          cache: true
      - run: go get
      - uses: actions/download-artifact@v4
        with:
          name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
          path: llm/llama.cpp/build
      - run: go build
      - run: go test -v ./...
      - uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.os }}-binaries
          path: ollama
```
`.gitmodules` (9 changed lines):

```diff
@@ -1,5 +1,4 @@
-[submodule "llm/llama.cpp/gguf"]
-	path = llm/llama.cpp/gguf
-	url = https://github.com/ggerganov/llama.cpp.git
-	ignore = dirty
-	shallow = true
+[submodule "llama.cpp"]
+	path = llm/llama.cpp
+	url = https://github.com/ggerganov/llama.cpp.git
+	shallow = true
```
`.golangci.yaml` (new file, 27 lines):

```yaml
run:
  timeout: 5m
linters:
  enable:
    - asasalint
    - bidichk
    - bodyclose
    - containedctx
    - contextcheck
    - exportloopref
    - gocheckcompilerdirectives
    # FIXME: for some reason this errors on windows
    # - gofmt
    # - goimports
    - misspell
    - nilerr
    - unused
linters-settings:
  errcheck:
    # exclude the following functions since we don't generally
    # need to be concerned with the returned errors
    exclude-functions:
      - encoding/binary.Read
      - (*os.File).Seek
      - (*bufio.Writer).WriteString
      - (*github.com/spf13/pflag.FlagSet).Set
      - (*github.com/jmorganca/ollama/llm.readSeekOffset).Seek
```
`Dockerfile` (27 → 135 lines; removed and added lines appear in sequence):

```dockerfile
@@ -1,27 +1,135 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ARG GOLANG_VERSION=1.21.3
ARG CMAKE_VERSION=3.22.1
ARG CUDA_VERSION=11.3.1

ARG TARGETARCH
ARG GOFLAGS="'-ldflags=-w -s'"
# Copy the minimal context we need to run the generate scripts
FROM scratch AS llm-code
COPY .git .git
COPY .gitmodules .gitmodules
COPY llm llm

FROM --platform=linux/amd64 nvidia/cuda:$CUDA_VERSION-devel-centos7 AS cuda-build-amd64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh

FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh

FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete AS rocm-5-build-amd64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ENV LIBRARY_PATH /opt/amdgpu/lib64
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
ARG AMDGPU_TARGETS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh

FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete AS rocm-6-build-amd64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ENV LIBRARY_PATH /opt/amdgpu/lib64
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
ARG AMDGPU_TARGETS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh

FROM --platform=linux/amd64 centos:7 AS cpu-builder-amd64
ARG CMAKE_VERSION
ARG GOLANG_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
ARG OLLAMA_CUSTOM_CPU_DEFS
ARG CGO_CFLAGS
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate

FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh

FROM --platform=linux/arm64 centos:7 AS cpu-build-arm64
ARG CMAKE_VERSION
ARG GOLANG_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
# Note, we only build the "base" CPU variant on arm since avx/avx2 are x86 features
ARG OLLAMA_CUSTOM_CPU_DEFS
ARG CGO_CFLAGS
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh

# Intermediate stage used for ./scripts/build_linux.sh
FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64
ENV CGO_ENABLED 1
WORKDIR /go/src/github.com/jmorganca/ollama
RUN apt-get update && apt-get install -y git build-essential cmake
ADD https://dl.google.com/go/go1.21.3.linux-$TARGETARCH.tar.gz /tmp/go1.21.3.tar.gz
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.3.tar.gz

COPY . .
ENV GOARCH=$TARGETARCH
ENV GOFLAGS=$GOFLAGS
RUN /usr/local/go/bin/go generate ./... \
    && /usr/local/go/bin/go build .
COPY --from=cpu_avx-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=cpu_avx2-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=cuda-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=rocm-5-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=rocm-6-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
ARG GOFLAGS
ARG CGO_CFLAGS
RUN go build .

FROM ubuntu:22.04
# Intermediate stage used for ./scripts/build_linux.sh
FROM --platform=linux/arm64 cpu-build-arm64 AS build-arm64
ENV CGO_ENABLED 1
ARG GOLANG_VERSION
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
COPY --from=cuda-build-arm64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
ARG GOFLAGS
ARG CGO_CFLAGS
RUN go build .

# Runtime stages
FROM --platform=linux/amd64 ubuntu:22.04 as runtime-amd64
RUN apt-get update && apt-get install -y ca-certificates
COPY --from=0 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
FROM --platform=linux/arm64 ubuntu:22.04 as runtime-arm64
RUN apt-get update && apt-get install -y ca-certificates
COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama

# Radeon images are much larger so we keep it distinct from the CPU/CUDA image
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete as runtime-rocm
RUN update-pciids
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0

# set some environment variable for better NVIDIA compatibility
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]

FROM runtime-$TARGETARCH
EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
```
Deleted file (74 lines; file name not shown in this render):

```dockerfile
@@ -1,74 +0,0 @@
# Ubuntu 20.04 amd64 dependencies
FROM --platform=linux/amd64 ubuntu:20.04 AS base-amd64
ARG CUDA_VERSION=11.3.1-1
ARG CMAKE_VERSION=3.22.1
# ROCm only supports amd64
ARG ROCM_VERSION=6.0
ARG CLBLAST_VER=1.6.1

# Note: https://rocm.docs.amd.com/en/latest/release/user_kernel_space_compat_matrix.html
RUN apt-get update && \
    apt-get install -y wget gnupg && \
    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin && \
    mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600 && \
    apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
    echo "deb [by-hash=no] https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/ /" > /etc/apt/sources.list.d/cuda.list && \
    wget "https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-x86_64.sh" -O /tmp/cmake-installer.sh && \
    chmod +x /tmp/cmake-installer.sh && /tmp/cmake-installer.sh --skip-license --prefix=/usr && \
    mkdir --parents --mode=0755 /etc/apt/keyrings && \
    wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | gpg --dearmor > /etc/apt/keyrings/rocm.gpg && \
    echo "deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/${ROCM_VERSION} focal main" > /etc/apt/sources.list.d/rocm.list && \
    echo "Package: *" > /etc/apt/preferences.d/rocm-pin-600 && \
    echo "Pin: release o=repo.radeon.com" >> /etc/apt/preferences.d/rocm-pin-600 && \
    echo "Pin-Priority: 600" >> /etc/apt/preferences.d/rocm-pin-600 && \
    apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get -y install cuda=${CUDA_VERSION} rocm-hip-libraries rocm-device-libs rocm-libs rocm-ocl-icd rocm-hip-sdk rocm-hip-libraries rocm-cmake rocm-clang-ocl rocm-dev

# CLBlast
RUN wget -qO- https://github.com/CNugteren/CLBlast/archive/refs/tags/${CLBLAST_VER}.tar.gz | tar zxv -C /tmp/ && \
    cd /tmp/CLBlast-${CLBLAST_VER} && mkdir build && cd build && cmake .. && make && make install

ENV ROCM_PATH=/opt/rocm

# Ubuntu 22.04 arm64 dependencies
FROM --platform=linux/arm64 ubuntu:20.04 AS base-arm64
ARG CUDA_VERSION=11.3.1-1
ARG CMAKE_VERSION=3.27.6
RUN apt-get update && \
    apt-get install -y wget gnupg && \
    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/sbsa/cuda-ubuntu2004.pin && \
    mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600 && \
    apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/sbsa//3bf863cc.pub && \
    echo "deb [by-hash=no] https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/sbsa/ /" > /etc/apt/sources.list.d/cuda.list && \
    wget "https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-aarch64.sh" -O /tmp/cmake-installer.sh && \
    chmod +x /tmp/cmake-installer.sh && /tmp/cmake-installer.sh --skip-license --prefix=/usr && \
    apt-get update && \
    apt-cache madison cuda && \
    DEBIAN_FRONTEND=noninteractive apt-get -y install cuda=${CUDA_VERSION}

FROM base-${TARGETARCH}
ARG TARGETARCH
ARG GOFLAGS="'-ldflags -w -s'"
ARG CGO_CFLAGS
ARG GOLANG_VERSION=1.21.3

# Common toolchain
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y gcc-10 g++-10 cpp-10 git ocl-icd-opencl-dev && \
    update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 --slave /usr/bin/g++ g++ /usr/bin/g++-10 --slave /usr/bin/gcov gcov /usr/bin/gcov-10

# install go
ADD https://dl.google.com/go/go${GOLANG_VERSION}.linux-$TARGETARCH.tar.gz /tmp/go${GOLANG_VERSION}.tar.gz
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go${GOLANG_VERSION}.tar.gz

# build the final binary
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .

ENV GOOS=linux
ENV GOARCH=$TARGETARCH
ENV GOFLAGS=$GOFLAGS
ENV CGO_CFLAGS=${CGO_CFLAGS}

RUN /usr/local/go/bin/go generate ./... && \
    /usr/local/go/bin/go build .
```
`README.md` (25 changed lines; removed and added lines appear in sequence):

````
@@ -1,8 +1,5 @@
<div align="center">
<picture>
<source media="(prefers-color-scheme: dark)" height="200px" srcset="https://github.com/jmorganca/ollama/assets/3325447/56ea1849-1284-4645-8970-956de6e51c3c">
<img alt="logo" height="200px" src="https://github.com/jmorganca/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</picture>
<img alt="ollama" height="200px" src="https://github.com/jmorganca/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</div>

# Ollama

@@ -31,6 +28,11 @@ curl https://ollama.ai/install.sh | sh

The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `ollama/ollama` is available on Docker Hub.

### Libraries

- [ollama-python](https://github.com/ollama/ollama-python)
- [ollama-js](https://github.com/ollama/ollama-js)

## Quickstart

To run and chat with [Llama 2](https://ollama.ai/library/llama2):

@@ -198,18 +200,21 @@ brew install cmake go
```

Then generate dependencies:

```
go generate ./...
```

Then build the binary:

```
go build .
```

More detailed instructions can be found in the [developer guide](https://github.com/jmorganca/ollama/blob/main/docs/development.md)

### Running local builds

Next, start the server:

```

@@ -251,6 +256,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
## Community Integrations

### Web & Desktop

- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
@@ -263,7 +269,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
- [Ollama-SwiftUI](https://github.com/kghandour/Ollama-SwiftUI)

- [MindMac](https://mindmac.app)

### Terminal

@@ -292,6 +298,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html)
- [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
- [Ollama4j for Java](https://github.com/amithkoujalgi/ollama4j)
- [ModelFusion Typescript Library](https://modelfusion.dev/integration/model-provider/ollama)
@@ -299,6 +306,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama for Dart](https://github.com/breitburg/dart-ollama)
- [Ollama for Laravel](https://github.com/cloudstudio/ollama-laravel)
- [LangChainDart](https://github.com/davidmigloz/langchain_dart)
- [Semantic Kernel - Python](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/ai/ollama)
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)

### Mobile

@@ -319,3 +329,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
````
`api/client.go` (modified; removed and added lines appear in sequence):

```go
@@ -309,6 +309,13 @@ func (c *Client) Heartbeat(ctx context.Context) error {
	}
	return nil
}
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
	var resp EmbeddingResponse
	if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
		return nil, err
	}
	return &resp, nil
}

func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
	if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil {
```
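The hunk above adds an Embeddings helper to the Go client, posting an EmbeddingRequest to /api/embeddings. A minimal sketch of how a caller might use it, assuming the api package as shown in this diff; the model name and prompt are illustrative placeholders:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// ClientFromEnvironment reads OLLAMA_HOST, defaulting to http://localhost:11434.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Model and prompt are placeholders for this sketch.
	resp, err := client.Embeddings(context.Background(), &api.EmbeddingRequest{
		Model:  "llama2",
		Prompt: "why is the sky blue?",
	})
	if err != nil {
		log.Fatal(err)
	}

	// The response carries the embedding vector returned by /api/embeddings.
	fmt.Println(len(resp.Embedding))
}
```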
`api/client.py` (deleted, 284 lines):

```python
import os
import json
import requests
import os
import hashlib
import json
from pathlib import Path

BASE_URL = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')

# Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses.
# The final response object will include statistics and additional data from the request. Use the callback function to override
# the default handler.
def generate(model_name, prompt, system=None, template=None, format="", context=None, options=None, callback=None):
    try:
        url = f"{BASE_URL}/api/generate"
        payload = {
            "model": model_name,
            "prompt": prompt,
            "system": system,
            "template": template,
            "context": context,
            "options": options,
            "format": format,
        }

        # Remove keys with None values
        payload = {k: v for k, v in payload.items() if v is not None}

        with requests.post(url, json=payload, stream=True) as response:
            response.raise_for_status()

            # Creating a variable to hold the context history of the final chunk
            final_context = None

            # Variable to hold concatenated response strings if no callback is provided
            full_response = ""

            # Iterating over the response line by line and displaying the details
            for line in response.iter_lines():
                if line:
                    # Parsing each line (JSON chunk) and extracting the details
                    chunk = json.loads(line)

                    # If a callback function is provided, call it with the chunk
                    if callback:
                        callback(chunk)
                    else:
                        # If this is not the last chunk, add the "response" field value to full_response and print it
                        if not chunk.get("done"):
                            response_piece = chunk.get("response", "")
                            full_response += response_piece
                            print(response_piece, end="", flush=True)

                    # Check if it's the last chunk (done is true)
                    if chunk.get("done"):
                        final_context = chunk.get("context")

            # Return the full response and the final context
            return full_response, final_context
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return None, None


# Create a blob file on the server if it doesn't exist.
def create_blob(digest, file_path):
    url = f"{BASE_URL}/api/blobs/{digest}"

    # Check if the blob exists
    response = requests.head(url)
    if response.status_code != 404:
        return  # Blob already exists, no need to upload
    response.raise_for_status()

    # Upload the blob
    with open(file_path, 'rb') as file_data:
        requests.post(url, data=file_data)


# Create a model from a Modelfile. Use the callback function to override the default handler.
def create(model_name, filename, callback=None):
    try:
        file_path = Path(filename).expanduser().resolve()
        processed_lines = []

        # Read and process the modelfile
        with open(file_path, 'r') as f:
            for line in f:
                # Skip empty or whitespace-only lines
                if not line.strip():
                    continue

                command, args = line.split(maxsplit=1)

                if command.upper() in ["FROM", "ADAPTER"]:
                    path = Path(args.strip()).expanduser()

                    # Check if path is relative and resolve it
                    if not path.is_absolute():
                        path = (file_path.parent / path)

                    # Skip if file does not exist for "model", this is handled by the server
                    if not path.exists():
                        processed_lines.append(line)
                        continue

                    # Calculate SHA-256 hash
                    with open(path, 'rb') as bin_file:
                        hash = hashlib.sha256()
                        hash.update(bin_file.read())
                        blob = f"sha256:{hash.hexdigest()}"

                    # Add the file to the remote server
                    create_blob(blob, path)

                    # Replace path with digest in the line
                    line = f"{command} @{blob}\n"

                processed_lines.append(line)

        # Combine processed lines back into a single string
        modelfile_content = '\n'.join(processed_lines)

        url = f"{BASE_URL}/api/create"
        payload = {"name": model_name, "modelfile": modelfile_content}

        # Making a POST request with the stream parameter set to True to handle streaming responses
        with requests.post(url, json=payload, stream=True) as response:
            response.raise_for_status()
            # Iterating over the response line by line and displaying the status
            for line in response.iter_lines():
                if line:
                    chunk = json.loads(line)
                    if callback:
                        callback(chunk)
                    else:
                        print(f"Status: {chunk.get('status')}")

    except Exception as e:
        print(f"An error occurred: {e}")


# Pull a model from a the model registry. Cancelled pulls are resumed from where they left off, and multiple
# calls to will share the same download progress. Use the callback function to override the default handler.
def pull(model_name, insecure=False, callback=None):
    try:
        url = f"{BASE_URL}/api/pull"
        payload = {
            "name": model_name,
            "insecure": insecure
        }

        # Making a POST request with the stream parameter set to True to handle streaming responses
        with requests.post(url, json=payload, stream=True) as response:
            response.raise_for_status()

            # Iterating over the response line by line and displaying the details
            for line in response.iter_lines():
                if line:
                    # Parsing each line (JSON chunk) and extracting the details
                    chunk = json.loads(line)

                    # If a callback function is provided, call it with the chunk
                    if callback:
                        callback(chunk)
                    else:
                        # Print the status message directly to the console
                        print(chunk.get('status', ''), end='', flush=True)

                        # If there's layer data, you might also want to print that (adjust as necessary)
                        if 'digest' in chunk:
                            print(f" - Digest: {chunk['digest']}", end='', flush=True)
                            print(f" - Total: {chunk['total']}", end='', flush=True)
                            print(f" - Completed: {chunk['completed']}", end='\n', flush=True)
                        else:
                            print()
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")


# Push a model to the model registry. Use the callback function to override the default handler.
def push(model_name, insecure=False, callback=None):
    try:
        url = f"{BASE_URL}/api/push"
        payload = {
            "name": model_name,
            "insecure": insecure
        }

        # Making a POST request with the stream parameter set to True to handle streaming responses
        with requests.post(url, json=payload, stream=True) as response:
            response.raise_for_status()

            # Iterating over the response line by line and displaying the details
            for line in response.iter_lines():
                if line:
                    # Parsing each line (JSON chunk) and extracting the details
                    chunk = json.loads(line)

                    # If a callback function is provided, call it with the chunk
                    if callback:
                        callback(chunk)
                    else:
                        # Print the status message directly to the console
                        print(chunk.get('status', ''), end='', flush=True)

                        # If there's layer data, you might also want to print that (adjust as necessary)
                        if 'digest' in chunk:
                            print(f" - Digest: {chunk['digest']}", end='', flush=True)
                            print(f" - Total: {chunk['total']}", end='', flush=True)
                            print(f" - Completed: {chunk['completed']}", end='\n', flush=True)
                        else:
                            print()
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")


# List models that are available locally.
def list():
    try:
        response = requests.get(f"{BASE_URL}/api/tags")
        response.raise_for_status()
        data = response.json()
        models = data.get('models', [])
        return models

    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return None


# Copy a model. Creates a model with another name from an existing model.
def copy(source, destination):
    try:
        # Create the JSON payload
        payload = {
            "source": source,
            "destination": destination
        }

        response = requests.post(f"{BASE_URL}/api/copy", json=payload)
        response.raise_for_status()

        # If the request was successful, return a message indicating that the copy was successful
        return "Copy successful"

    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return None


# Delete a model and its data.
def delete(model_name):
    try:
        url = f"{BASE_URL}/api/delete"
        payload = {"name": model_name}
        response = requests.delete(url, json=payload)
        response.raise_for_status()
        return "Delete successful"
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return None


# Show info about a model.
def show(model_name):
    try:
        url = f"{BASE_URL}/api/show"
        payload = {"name": model_name}
        response = requests.post(url, json=payload)
        response.raise_for_status()

        # Parse the JSON response and return it
        data = response.json()
        return data
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return None


def heartbeat():
    try:
        url = f"{BASE_URL}/"
        response = requests.head(url)
        response.raise_for_status()
        return "Ollama is running"
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return "Ollama is not running"
```
`api/types.go` (70 changed lines; removed and added lines appear in sequence):

```go
@@ -34,24 +34,26 @@ func (e StatusError) Error() string {
type ImageData []byte

type GenerateRequest struct {
	Model string `json:"model"`
	Prompt string `json:"prompt"`
	System string `json:"system"`
	Template string `json:"template"`
	Context []int `json:"context,omitempty"`
	Stream *bool `json:"stream,omitempty"`
	Raw bool `json:"raw,omitempty"`
	Format string `json:"format"`
	Images []ImageData `json:"images,omitempty"`
	Model string `json:"model"`
	Prompt string `json:"prompt"`
	System string `json:"system"`
	Template string `json:"template"`
	Context []int `json:"context,omitempty"`
	Stream *bool `json:"stream,omitempty"`
	Raw bool `json:"raw,omitempty"`
	Format string `json:"format"`
	KeepAlive *Duration `json:"keep_alive,omitempty"`
	Images []ImageData `json:"images,omitempty"`

	Options map[string]interface{} `json:"options"`
}

type ChatRequest struct {
	Model string `json:"model"`
	Messages []Message `json:"messages"`
	Stream *bool `json:"stream,omitempty"`
	Format string `json:"format"`
	Model string `json:"model"`
	Messages []Message `json:"messages"`
	Stream *bool `json:"stream,omitempty"`
	Format string `json:"format"`
	KeepAlive *Duration `json:"keep_alive,omitempty"`

	Options map[string]interface{} `json:"options"`
}
@@ -126,8 +128,9 @@ type Runner struct {
}

type EmbeddingRequest struct {
	Model string `json:"model"`
	Prompt string `json:"prompt"`
	Model string `json:"model"`
	Prompt string `json:"prompt"`
	KeepAlive *Duration `json:"keep_alive,omitempty"`

	Options map[string]interface{} `json:"options"`
}
@@ -137,17 +140,30 @@ type EmbeddingResponse struct {
}

type CreateRequest struct {
	Name string `json:"name"`
	Model string `json:"model"`
	Path string `json:"path"`
	Modelfile string `json:"modelfile"`
	Stream *bool `json:"stream,omitempty"`

	// Name is deprecated, see Model
	Name string `json:"name"`
}

type DeleteRequest struct {
	Model string `json:"model"`

	// Name is deprecated, see Model
	Name string `json:"name"`
}

type ShowRequest struct {
	Model string `json:"model"`
	System string `json:"system"`
	Template string `json:"template"`

	Options map[string]interface{} `json:"options"`

	// Name is deprecated, see Model
	Name string `json:"name"`
}

@@ -158,6 +174,7 @@ type ShowResponse struct {
	Template string `json:"template,omitempty"`
	System string `json:"system,omitempty"`
	Details ModelDetails `json:"details,omitempty"`
	Messages []Message `json:"messages,omitempty"`
}

type CopyRequest struct {
@@ -166,11 +183,14 @@ type CopyRequest struct {
}

type PullRequest struct {
	Name string `json:"name"`
	Model string `json:"model"`
	Insecure bool `json:"insecure,omitempty"`
	Username string `json:"username"`
	Password string `json:"password"`
	Stream *bool `json:"stream,omitempty"`

	// Name is deprecated, see Model
	Name string `json:"name"`
}

type ProgressResponse struct {
@@ -181,11 +201,14 @@ type ProgressResponse struct {
}

type PushRequest struct {
	Name string `json:"name"`
	Model string `json:"model"`
	Insecure bool `json:"insecure,omitempty"`
	Username string `json:"username"`
	Password string `json:"password"`
	Stream *bool `json:"stream,omitempty"`

	// Name is deprecated, see Model
	Name string `json:"name"`
}

type ListResponse struct {
@@ -194,6 +217,7 @@ type ListResponse struct {

type ModelResponse struct {
	Name string `json:"name"`
	Model string `json:"model"`
	ModifiedAt time.Time `json:"modified_at"`
	Size int64 `json:"size"`
	Digest string `json:"digest"`
@@ -216,6 +240,7 @@ type GenerateResponse struct {
}

type ModelDetails struct {
	ParentModel string `json:"parent_model"`
	Format string `json:"format"`
	Family string `json:"family"`
	Families []string `json:"families"`
@@ -391,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
	case float64:
		if t < 0 {
			t = math.MaxFloat64
			d.Duration = time.Duration(t)
		} else {
			d.Duration = time.Duration(t * float64(time.Second))
		}

		d.Duration = time.Duration(t)
	case string:
		d.Duration, err = time.ParseDuration(t)
		if err != nil {
			return err
		}
		if d.Duration < 0 {
			mf := math.MaxFloat64
			d.Duration = time.Duration(mf)
		}
	}

	return nil
```
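The new keep_alive field is a Duration, and the UnmarshalJSON change above makes a bare JSON number mean seconds, a string parse as a Go duration, and negative values clamp to the largest representable duration. A minimal standalone sketch of that behaviour, reimplementing the same logic locally rather than importing the ollama api package:

```go
package main

import (
	"encoding/json"
	"fmt"
	"math"
	"time"
)

// Duration mirrors the unmarshalling behaviour shown in the diff above:
// numbers are seconds, strings are Go durations, negatives are clamped.
type Duration struct {
	time.Duration
}

func (d *Duration) UnmarshalJSON(b []byte) error {
	var v interface{}
	if err := json.Unmarshal(b, &v); err != nil {
		return err
	}
	switch t := v.(type) {
	case float64:
		if t < 0 {
			d.Duration = time.Duration(math.MaxFloat64)
		} else {
			d.Duration = time.Duration(t * float64(time.Second))
		}
	case string:
		dur, err := time.ParseDuration(t)
		if err != nil {
			return err
		}
		if dur < 0 {
			dur = time.Duration(math.MaxFloat64)
		}
		d.Duration = dur
	}
	return nil
}

func main() {
	var a, b Duration
	_ = json.Unmarshal([]byte(`300`), &a)  // number: 300 seconds
	_ = json.Unmarshal([]byte(`"5m"`), &b) // string: Go duration syntax
	fmt.Println(a.Duration, b.Duration)    // 5m0s 5m0s
}
```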
`cmd/cmd.go` (822 changed lines; removed and added lines appear in sequence, and the listing is cut off mid-hunk as in the original page):

```go
@@ -17,7 +17,6 @@ import (
	"os/exec"
	"os/signal"
	"path/filepath"
	"regexp"
	"runtime"
	"strings"
	"syscall"
@@ -33,13 +32,10 @@ import (
	"github.com/jmorganca/ollama/format"
	"github.com/jmorganca/ollama/parser"
	"github.com/jmorganca/ollama/progress"
	"github.com/jmorganca/ollama/readline"
	"github.com/jmorganca/ollama/server"
	"github.com/jmorganca/ollama/version"
)

type ImageData []byte

func CreateHandler(cmd *cobra.Command, args []string) error {
	filename, _ := cmd.Flags().GetString("file")
	filename, err := filepath.Abs(filename)
@@ -151,19 +147,68 @@ func RunHandler(cmd *cobra.Command, args []string) error {
	}

	name := args[0]

	// check if the model exists on the server
	_, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
	show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
	var statusError api.StatusError
	switch {
	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
		if err := PullHandler(cmd, args); err != nil {
		if err := PullHandler(cmd, []string{name}); err != nil {
			return err
		}

		show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
		if err != nil {
			return err
		}
	case err != nil:
		return err
	}

	return RunGenerate(cmd, args)
	interactive := true

	opts := runOptions{
		Model:       args[0],
		WordWrap:    os.Getenv("TERM") == "xterm-256color",
		Options:     map[string]interface{}{},
		MultiModal:  slices.Contains(show.Details.Families, "clip"),
		ParentModel: show.Details.ParentModel,
	}

	format, err := cmd.Flags().GetString("format")
	if err != nil {
		return err
	}
	opts.Format = format

	prompts := args[1:]
	// prepend stdin to the prompt if provided
	if !term.IsTerminal(int(os.Stdin.Fd())) {
		in, err := io.ReadAll(os.Stdin)
		if err != nil {
			return err
		}

		prompts = append([]string{string(in)}, prompts...)
		opts.WordWrap = false
		interactive = false
	}
	opts.Prompt = strings.Join(prompts, " ")
	if len(prompts) > 0 {
		interactive = false
	}

	nowrap, err := cmd.Flags().GetBool("nowordwrap")
	if err != nil {
		return err
	}
	opts.WordWrap = !nowrap

	if !interactive {
		return generate(cmd, opts)
	}

	return generateInteractive(cmd, opts)
}

func PushHandler(cmd *cobra.Command, args []string) error {
@@ -415,66 +460,139 @@ func PullHandler(cmd *cobra.Command, args []string) error {
	return nil
}

func RunGenerate(cmd *cobra.Command, args []string) error {
	interactive := true

	opts := generateOptions{
		Model:    args[0],
		WordWrap: os.Getenv("TERM") == "xterm-256color",
		Options:  map[string]interface{}{},
		Images:   []ImageData{},
	}

	format, err := cmd.Flags().GetString("format")
	if err != nil {
		return err
	}
	opts.Format = format

	prompts := args[1:]
	// prepend stdin to the prompt if provided
	if !term.IsTerminal(int(os.Stdin.Fd())) {
		in, err := io.ReadAll(os.Stdin)
		if err != nil {
			return err
		}

		prompts = append([]string{string(in)}, prompts...)
		opts.WordWrap = false
		interactive = false
	}
	opts.Prompt = strings.Join(prompts, " ")
	if len(prompts) > 0 {
		interactive = false
	}

	nowrap, err := cmd.Flags().GetBool("nowordwrap")
	if err != nil {
		return err
	}
	opts.WordWrap = !nowrap

	if !interactive {
		return generate(cmd, opts)
	}

	return generateInteractive(cmd, opts)
}

type generateContextKey string

type generateOptions struct {
	Model    string
	Prompt   string
	WordWrap bool
	Format   string
	System   string
	Template string
	Images   []ImageData
	Options  map[string]interface{}
type runOptions struct {
	Model       string
	ParentModel string
	Prompt      string
	Messages    []api.Message
	WordWrap    bool
	Format      string
	System      string
	Template    string
	Images      []api.ImageData
	Options     map[string]interface{}
	MultiModal  bool
}

func generate(cmd *cobra.Command, opts generateOptions) error {
type displayResponseState struct {
	lineLength int
	wordBuffer string
}

func displayResponse(content string, wordWrap bool, state *displayResponseState) {
	termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
	if wordWrap && termWidth >= 10 {
		for _, ch := range content {
			if state.lineLength+1 > termWidth-5 {
				if len(state.wordBuffer) > termWidth-10 {
					fmt.Printf("%s%c", state.wordBuffer, ch)
					state.wordBuffer = ""
					state.lineLength = 0
					continue
				}

				// backtrack the length of the last word and clear to the end of the line
				fmt.Printf("\x1b[%dD\x1b[K\n", len(state.wordBuffer))
				fmt.Printf("%s%c", state.wordBuffer, ch)
				state.lineLength = len(state.wordBuffer) + 1
			} else {
				fmt.Print(string(ch))
				state.lineLength += 1

				switch ch {
				case ' ':
					state.wordBuffer = ""
				case '\n':
					state.lineLength = 0
				default:
					state.wordBuffer += string(ch)
				}
			}
		}
	} else {
		fmt.Printf("%s%s", state.wordBuffer, content)
		if len(state.wordBuffer) > 0 {
			state.wordBuffer = ""
		}
	}
}

func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return nil, err
	}

	p := progress.NewProgress(os.Stderr)
	defer p.StopAndClear()

	spinner := progress.NewSpinner("")
	p.Add("", spinner)

	cancelCtx, cancel := context.WithCancel(cmd.Context())
	defer cancel()

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT)

	go func() {
		<-sigChan
		cancel()
	}()

	var state *displayResponseState = &displayResponseState{}
	var latest api.ChatResponse
	var fullResponse strings.Builder
	var role string

	fn := func(response api.ChatResponse) error {
		p.StopAndClear()

		latest = response

		role = response.Message.Role
		content := response.Message.Content
		fullResponse.WriteString(content)

		displayResponse(content, opts.WordWrap, state)

		return nil
	}

	req := &api.ChatRequest{
		Model:    opts.Model,
		Messages: opts.Messages,
		Format:   opts.Format,
		Options:  opts.Options,
	}

	if err := client.Chat(cancelCtx, req, fn); err != nil {
		if errors.Is(err, context.Canceled) {
			return nil, nil
		}
		return nil, err
	}

	if len(opts.Messages) > 0 {
		fmt.Println()
		fmt.Println()
	}

	verbose, err := cmd.Flags().GetBool("verbose")
	if err != nil {
		return nil, err
	}

	if verbose {
		latest.Summary()
	}

	return &api.Message{Role: role, Content: fullResponse.String()}, nil
}

func generate(cmd *cobra.Command, opts runOptions) error {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
@@ -493,11 +611,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
		generateContext = []int{}
	}

	termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
	if err != nil {
		opts.WordWrap = false
	}

	ctx, cancel := context.WithCancel(cmd.Context())
	defer cancel()

@@ -509,94 +622,44 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
		cancel()
	}()

	var currentLineLength int
	var wordBuffer string
	var state *displayResponseState = &displayResponseState{}

	fn := func(response api.GenerateResponse) error {
		p.StopAndClear()

		latest = response
		content := response.Response

		termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
		if opts.WordWrap && termWidth >= 10 {
			for _, ch := range response.Response {
				if currentLineLength+1 > termWidth-5 {
					if len(wordBuffer) > termWidth-10 {
						fmt.Printf("%s%c", wordBuffer, ch)
						wordBuffer = ""
						currentLineLength = 0
						continue
					}

					// backtrack the length of the last word and clear to the end of the line
					fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
					fmt.Printf("%s%c", wordBuffer, ch)
					currentLineLength = len(wordBuffer) + 1
				} else {
					fmt.Print(string(ch))
					currentLineLength += 1

					switch ch {
					case ' ':
						wordBuffer = ""
					case '\n':
						currentLineLength = 0
					default:
						wordBuffer += string(ch)
					}
				}
			}
		} else {
			fmt.Printf("%s%s", wordBuffer, response.Response)
			if len(wordBuffer) > 0 {
				wordBuffer = ""
			}
		}
		displayResponse(content, opts.WordWrap, state)

		return nil
	}

	images := make([]api.ImageData, 0)
	for _, i := range opts.Images {
		images = append(images, api.ImageData(i))
	if opts.MultiModal {
		opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
		if err != nil {
			return err
		}
	}

	request := api.GenerateRequest{
		Model:    opts.Model,
		Prompt:   opts.Prompt,
		Context:  generateContext,
		Images:   opts.Images,
		Format:   opts.Format,
		System:   opts.System,
		Template: opts.Template,
		Options:  opts.Options,
		Images:   images,
	}

	if err := client.Generate(ctx, &request, fn); err != nil {
		switch {
		case errors.Is(err, context.Canceled):
		if errors.Is(err, context.Canceled) {
			return nil
		case strings.Contains(err.Error(), "unsupported model format"):
			// pull and retry to see if the model has been updated
			parts := strings.Split(opts.Model, string(os.PathSeparator))
			if len(parts) == 1 {
				// this is a library model, log some info
				fmt.Fprintln(os.Stderr, "This model is no longer compatible with Ollama. Pulling a new version...")
			}
			if err := PullHandler(cmd, []string{opts.Model}); err != nil {
				fmt.Printf("Error: %s\n", err)
				return fmt.Errorf("unsupported model, please update this model to gguf format") // relay the original error
			}
			// retry
			if err := client.Generate(ctx, &request, fn); err != nil {
				if errors.Is(err, context.Canceled) {
					return nil
				}
				return err
			}
		default:
			return err
		}
		return err
	}

	if opts.Prompt != "" {
		fmt.Println()
		fmt.Println()
@@ -621,459 +684,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
	return nil
}

type MultilineState int

const (
	MultilineNone MultilineState = iota
	MultilinePrompt
	MultilineSystem
	MultilineTemplate
)

func modelIsMultiModal(cmd *cobra.Command, name string) bool {
	// get model details
	client, err := api.ClientFromEnvironment()
	if err != nil {
		fmt.Println("error: couldn't connect to ollama server")
		return false
	}

	req := api.ShowRequest{Name: name}
	resp, err := client.Show(cmd.Context(), &req)
	if err != nil {
		return false
	}

	return slices.Contains(resp.Details.Families, "clip")
}

func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
	multiModal := modelIsMultiModal(cmd, opts.Model)

	// load the model
	loadOpts := generateOptions{
		Model:  opts.Model,
		Prompt: "",
		Images: []ImageData{},
	}
	if err := generate(cmd, loadOpts); err != nil {
		return err
	}

	usage := func() {
		fmt.Fprintln(os.Stderr, "Available Commands:")
		fmt.Fprintln(os.Stderr, "  /set          Set session variables")
		fmt.Fprintln(os.Stderr, "  /show         Show model information")
		fmt.Fprintln(os.Stderr, "  /bye          Exit")
		fmt.Fprintln(os.Stderr, "  /?, /help     Help for a command")
		fmt.Fprintln(os.Stderr, "  /? shortcuts  Help for keyboard shortcuts")
		fmt.Fprintln(os.Stderr, "")
		fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
		fmt.Fprintln(os.Stderr, "")
	}

	usageSet := func() {
		fmt.Fprintln(os.Stderr, "Available Commands:")
		fmt.Fprintln(os.Stderr, "  /set parameter ...      Set a parameter")
		fmt.Fprintln(os.Stderr, "  /set system <string>    Set system message")
		fmt.Fprintln(os.Stderr, "  /set template <string>  Set prompt template")
		fmt.Fprintln(os.Stderr, "  /set history            Enable history")
		fmt.Fprintln(os.Stderr, "  /set nohistory          Disable history")
		fmt.Fprintln(os.Stderr, "  /set wordwrap           Enable wordwrap")
		fmt.Fprintln(os.Stderr, "  /set nowordwrap         Disable wordwrap")
		fmt.Fprintln(os.Stderr, "  /set format json        Enable JSON mode")
		fmt.Fprintln(os.Stderr, "  /set noformat           Disable formatting")
		fmt.Fprintln(os.Stderr, "  /set verbose            Show LLM stats")
		fmt.Fprintln(os.Stderr, "  /set quiet               Disable LLM stats")
		fmt.Fprintln(os.Stderr, "")
	}

	usageShortcuts := func() {
		fmt.Fprintln(os.Stderr, "Available keyboard shortcuts:")
		fmt.Fprintln(os.Stderr, "  Ctrl + a   Move to the beginning of the line (Home)")
		fmt.Fprintln(os.Stderr, "  Ctrl + e   Move to the end of the line (End)")
		fmt.Fprintln(os.Stderr, "  Alt + b    Move back (left) one word")
		fmt.Fprintln(os.Stderr, "  Alt + f    Move forward (right) one word")
		fmt.Fprintln(os.Stderr, "  Ctrl + k   Delete the sentence after the cursor")
		fmt.Fprintln(os.Stderr, "  Ctrl + u   Delete the sentence before the cursor")
		fmt.Fprintln(os.Stderr, "")
		fmt.Fprintln(os.Stderr, "  Ctrl + l   Clear the screen")
		fmt.Fprintln(os.Stderr, "  Ctrl + c   Stop the model from responding")
		fmt.Fprintln(os.Stderr, "  Ctrl + d   Exit ollama (/bye)")
		fmt.Fprintln(os.Stderr, "")
	}

	usageShow := func() {
		fmt.Fprintln(os.Stderr, "Available Commands:")
		fmt.Fprintln(os.Stderr, "  /show license     Show model license")
		fmt.Fprintln(os.Stderr, "  /show modelfile   Show Modelfile for this model")
		fmt.Fprintln(os.Stderr, "  /show parameters  Show parameters for this model")
		fmt.Fprintln(os.Stderr, "  /show system      Show system message")
		fmt.Fprintln(os.Stderr, "  /show template    Show prompt template")
		fmt.Fprintln(os.Stderr, "")
	}

	// only list out the most common parameters
	usageParameters := func() {
		fmt.Fprintln(os.Stderr, "Available Parameters:")
		fmt.Fprintln(os.Stderr, "  /set parameter seed <int>         Random number seed")
		fmt.Fprintln(os.Stderr, "  /set parameter num_predict <int>  Max number of tokens to predict")
		fmt.Fprintln(os.Stderr, "  /set parameter top_k <int>        Pick from top k num of tokens")
```
|
||||
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Print(readline.StartBracketedPaste)
|
||||
defer fmt.Printf(readline.EndBracketedPaste)
|
||||
|
||||
var multiline MultilineState
|
||||
var prompt string
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
fmt.Println()
|
||||
return nil
|
||||
case errors.Is(err, readline.ErrInterrupt):
|
||||
if line == "" {
|
||||
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||
}
|
||||
|
||||
scanner.Prompt.UseAlt = false
|
||||
prompt = ""
|
||||
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(prompt, `"""`):
|
||||
// if the prompt so far starts with """ then we're in multiline mode
|
||||
// and we need to keep reading until we find a line that ends with """
|
||||
cut, found := strings.CutSuffix(line, `"""`)
|
||||
prompt += cut
|
||||
|
||||
if !found {
|
||||
prompt += "\n"
|
||||
continue
|
||||
}
|
||||
|
||||
prompt = strings.TrimPrefix(prompt, `"""`)
|
||||
scanner.Prompt.UseAlt = false
|
||||
|
||||
switch multiline {
|
||||
case MultilineSystem:
|
||||
opts.System = prompt
|
||||
prompt = ""
|
||||
fmt.Println("Set system message.")
|
||||
case MultilineTemplate:
|
||||
opts.Template = prompt
|
||||
prompt = ""
|
||||
fmt.Println("Set prompt template.")
|
||||
}
|
||||
multiline = MultilineNone
|
||||
case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
|
||||
scanner.Prompt.UseAlt = true
|
||||
multiline = MultilinePrompt
|
||||
prompt += line + "\n"
|
||||
continue
|
||||
case scanner.Pasting:
|
||||
prompt += line + "\n"
|
||||
continue
|
||||
case strings.HasPrefix(line, "/list"):
|
||||
args := strings.Fields(line)
|
||||
if err := ListHandler(cmd, args[1:]); err != nil {
|
||||
return err
|
||||
}
|
||||
case strings.HasPrefix(line, "/set"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
switch args[1] {
|
||||
case "history":
|
||||
scanner.HistoryEnable()
|
||||
case "nohistory":
|
||||
scanner.HistoryDisable()
|
||||
case "wordwrap":
|
||||
opts.WordWrap = true
|
||||
fmt.Println("Set 'wordwrap' mode.")
|
||||
case "nowordwrap":
|
||||
opts.WordWrap = false
|
||||
fmt.Println("Set 'nowordwrap' mode.")
|
||||
case "verbose":
|
||||
cmd.Flags().Set("verbose", "true")
|
||||
fmt.Println("Set 'verbose' mode.")
|
||||
case "quiet":
|
||||
cmd.Flags().Set("verbose", "false")
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||
} else {
|
||||
opts.Format = args[2]
|
||||
fmt.Printf("Set format to '%s' mode.\n", args[2])
|
||||
}
|
||||
case "noformat":
|
||||
opts.Format = ""
|
||||
fmt.Println("Disabled format.")
|
||||
case "parameter":
|
||||
if len(args) < 4 {
|
||||
usageParameters()
|
||||
continue
|
||||
}
|
||||
var params []string
|
||||
for _, p := range args[3:] {
|
||||
params = append(params, p)
|
||||
}
|
||||
fp, err := api.FormatParams(map[string][]string{args[2]: params})
|
||||
if err != nil {
|
||||
fmt.Printf("Couldn't set parameter: %q\n\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", "))
|
||||
opts.Options[args[2]] = fp[args[2]]
|
||||
case "system", "template":
|
||||
if len(args) < 3 {
|
||||
usageSet()
|
||||
continue
|
||||
}
|
||||
line := strings.Join(args[2:], " ")
|
||||
line = strings.TrimPrefix(line, `"""`)
|
||||
if strings.HasPrefix(args[2], `"""`) {
|
||||
cut, found := strings.CutSuffix(line, `"""`)
|
||||
prompt += cut
|
||||
if found {
|
||||
if args[1] == "system" {
|
||||
opts.System = prompt
|
||||
fmt.Println("Set system message.")
|
||||
} else {
|
||||
opts.Template = prompt
|
||||
fmt.Println("Set prompt template.")
|
||||
}
|
||||
prompt = ""
|
||||
} else {
|
||||
prompt = `"""` + prompt + "\n"
|
||||
if args[1] == "system" {
|
||||
multiline = MultilineSystem
|
||||
} else {
|
||||
multiline = MultilineTemplate
|
||||
}
|
||||
scanner.Prompt.UseAlt = true
|
||||
}
|
||||
} else {
|
||||
opts.System = line
|
||||
fmt.Println("Set system message.")
|
||||
}
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
usageSet()
|
||||
}
|
||||
case strings.HasPrefix(line, "/show"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model})
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model")
|
||||
return err
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "license":
|
||||
if resp.License == "" {
|
||||
fmt.Print("No license was specified for this model.\n\n")
|
||||
} else {
|
||||
fmt.Println(resp.License)
|
||||
}
|
||||
case "modelfile":
|
||||
fmt.Println(resp.Modelfile)
|
||||
case "parameters":
|
||||
if resp.Parameters == "" {
|
||||
fmt.Print("No parameters were specified for this model.\n\n")
|
||||
} else {
|
||||
if len(opts.Options) > 0 {
|
||||
fmt.Println("User defined parameters:")
|
||||
for k, v := range opts.Options {
|
||||
fmt.Printf("%-*s %v\n", 30, k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Println("Model defined parameters:")
|
||||
fmt.Println(resp.Parameters)
|
||||
}
|
||||
case "system":
|
||||
switch {
|
||||
case opts.System != "":
|
||||
fmt.Println(opts.System + "\n")
|
||||
case resp.System != "":
|
||||
fmt.Println(resp.System + "\n")
|
||||
default:
|
||||
fmt.Print("No system message was specified for this model.\n\n")
|
||||
}
|
||||
case "template":
|
||||
switch {
|
||||
case opts.Template != "":
|
||||
fmt.Println(opts.Template + "\n")
|
||||
case resp.Template != "":
|
||||
fmt.Println(resp.Template)
|
||||
default:
|
||||
fmt.Print("No prompt template was specified for this model.\n\n")
|
||||
}
|
||||
default:
|
||||
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
usageShow()
|
||||
}
|
||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
switch args[1] {
|
||||
case "set", "/set":
|
||||
usageSet()
|
||||
case "show", "/show":
|
||||
usageShow()
|
||||
case "shortcut", "shortcuts":
|
||||
usageShortcuts()
|
||||
}
|
||||
} else {
|
||||
usage()
|
||||
}
|
||||
case line == "/exit", line == "/bye":
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/"):
|
||||
args := strings.Fields(line)
|
||||
isFile := false
|
||||
|
||||
if multiModal {
|
||||
for _, f := range extractFileNames(line) {
|
||||
if strings.HasPrefix(f, args[0]) {
|
||||
isFile = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isFile {
|
||||
prompt += line
|
||||
} else {
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
|
||||
continue
|
||||
}
|
||||
default:
|
||||
prompt += line
|
||||
}
|
||||
|
||||
if len(prompt) > 0 && multiline == MultilineNone {
|
||||
opts.Prompt = prompt
|
||||
if multiModal {
|
||||
newPrompt, images, err := extractFileData(prompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Prompt = newPrompt
|
||||
|
||||
// reset the context if we find another image
|
||||
if len(images) > 0 {
|
||||
opts.Images = images
|
||||
ctx := cmd.Context()
|
||||
ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
|
||||
cmd.SetContext(ctx)
|
||||
}
|
||||
if len(opts.Images) == 0 {
|
||||
fmt.Println("This model requires you to add a jpeg, png, or svg image.")
|
||||
fmt.Println()
|
||||
prompt = ""
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := generate(cmd, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prompt = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeFilePath(fp string) string {
|
||||
// Define a map of escaped characters and their replacements
|
||||
replacements := map[string]string{
|
||||
"\\ ": " ", // Escaped space
|
||||
"\\(": "(", // Escaped left parenthesis
|
||||
"\\)": ")", // Escaped right parenthesis
|
||||
"\\[": "[", // Escaped left square bracket
|
||||
"\\]": "]", // Escaped right square bracket
|
||||
"\\{": "{", // Escaped left curly brace
|
||||
"\\}": "}", // Escaped right curly brace
|
||||
"\\$": "$", // Escaped dollar sign
|
||||
"\\&": "&", // Escaped ampersand
|
||||
"\\;": ";", // Escaped semicolon
|
||||
"\\'": "'", // Escaped single quote
|
||||
"\\\\": "\\", // Escaped backslash
|
||||
"\\*": "*", // Escaped asterisk
|
||||
"\\?": "?", // Escaped question mark
|
||||
}
|
||||
|
||||
for escaped, actual := range replacements {
|
||||
fp = strings.ReplaceAll(fp, escaped, actual)
|
||||
}
|
||||
return fp
|
||||
}
|
||||
|
||||
func extractFileNames(input string) []string {
|
||||
// Regex to match file paths starting with / or ./ and include escaped spaces (\ or %20)
|
||||
// and followed by more characters and a file extension
|
||||
regexPattern := `(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
|
||||
return re.FindAllString(input, -1)
|
||||
}
|
||||
|
||||
func extractFileData(input string) (string, []ImageData, error) {
|
||||
filePaths := extractFileNames(input)
|
||||
var imgs []ImageData
|
||||
|
||||
for _, fp := range filePaths {
|
||||
nfp := normalizeFilePath(fp)
|
||||
data, err := getImageData(nfp)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Couldn't process image: %q\n", err)
|
||||
return "", imgs, err
|
||||
}
|
||||
fmt.Printf("Added image '%s'\n", nfp)
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
return input, imgs, nil
|
||||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
|
||||
if err != nil {
|
||||
@@ -1095,50 +705,6 @@ func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
return server.Serve(ln)
|
||||
}
|
||||
|
||||
func getImageData(filePath string) ([]byte, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := make([]byte, 512)
|
||||
_, err = file.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
}
|
||||
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the file size exceeds 100MB
|
||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
|
||||
if info.Size() > maxSize {
|
||||
return nil, fmt.Errorf("file size exceeds maximum limit (100MB)")
|
||||
}
|
||||
|
||||
buf = make([]byte, info.Size())
|
||||
_, err = file.Seek(0, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.ReadFull(file, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func initializeKeypair() error {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
|
656 cmd/interactive.go Normal file
@@ -0,0 +1,656 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/progress"
|
||||
"github.com/jmorganca/ollama/readline"
|
||||
)
|
||||
|
||||
type MultilineState int
|
||||
|
||||
const (
|
||||
MultilineNone MultilineState = iota
|
||||
MultilinePrompt
|
||||
MultilineSystem
|
||||
MultilineTemplate
|
||||
)
|
||||
|
||||
func loadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.StopAndClear()
|
||||
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
showReq := api.ShowRequest{Name: opts.Model}
|
||||
showResp, err := client.Show(cmd.Context(), &showReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
|
||||
opts.ParentModel = showResp.Details.ParentModel
|
||||
|
||||
if len(showResp.Messages) > 0 {
|
||||
opts.Messages = append(opts.Messages, showResp.Messages...)
|
||||
}
|
||||
|
||||
chatReq := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: []api.Message{},
|
||||
}
|
||||
err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
|
||||
p.StopAndClear()
|
||||
if len(opts.Messages) > 0 {
|
||||
for _, msg := range opts.Messages {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
fmt.Printf(">>> %s\n", msg.Content)
|
||||
case "assistant":
|
||||
state := &displayResponseState{}
|
||||
displayResponse(msg.Content, opts.WordWrap, state)
|
||||
fmt.Println()
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
opts.Messages = make([]api.Message, 0)
|
||||
|
||||
err := loadModel(cmd, &opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usage := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
||||
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
if opts.MultiModal {
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
usageSet := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
|
||||
fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
|
||||
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
|
||||
fmt.Fprintln(os.Stderr, " /set history Enable history")
|
||||
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
|
||||
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
|
||||
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
|
||||
fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode")
|
||||
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
|
||||
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
usageShortcuts := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available keyboard shortcuts:")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + a Move to the beginning of the line (Home)")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + e Move to the end of the line (End)")
|
||||
fmt.Fprintln(os.Stderr, " Alt + b Move back (left) one word")
|
||||
fmt.Fprintln(os.Stderr, " Alt + f Move forward (right) one word")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + k Delete the sentence after the cursor")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + u Delete the sentence before the cursor")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
usageShow := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /show info Show details for this model")
|
||||
fmt.Fprintln(os.Stderr, " /show license Show model license")
|
||||
fmt.Fprintln(os.Stderr, " /show modelfile Show Modelfile for this model")
|
||||
fmt.Fprintln(os.Stderr, " /show parameters Show parameters for this model")
|
||||
fmt.Fprintln(os.Stderr, " /show system Show system message")
|
||||
fmt.Fprintln(os.Stderr, " /show template Show prompt template")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
// only list out the most common parameters
|
||||
usageParameters := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Parameters:")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter seed <int> Random number seed")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Print(readline.StartBracketedPaste)
|
||||
defer fmt.Printf(readline.EndBracketedPaste)
|
||||
|
||||
var sb strings.Builder
|
||||
var multiline MultilineState
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
fmt.Println()
|
||||
return nil
|
||||
case errors.Is(err, readline.ErrInterrupt):
|
||||
if line == "" {
|
||||
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||
}
|
||||
|
||||
scanner.Prompt.UseAlt = false
|
||||
sb.Reset()
|
||||
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case multiline != MultilineNone:
|
||||
// check if there's a multiline terminating string
|
||||
before, ok := strings.CutSuffix(line, `"""`)
|
||||
sb.WriteString(before)
|
||||
if !ok {
|
||||
fmt.Fprintln(&sb)
|
||||
continue
|
||||
}
|
||||
|
||||
switch multiline {
|
||||
case MultilineSystem:
|
||||
opts.System = sb.String()
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
case MultilineTemplate:
|
||||
opts.Template = sb.String()
|
||||
fmt.Println("Set prompt template.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
multiline = MultilineNone
|
||||
scanner.Prompt.UseAlt = false
|
||||
case strings.HasPrefix(line, `"""`):
|
||||
line := strings.TrimPrefix(line, `"""`)
|
||||
line, ok := strings.CutSuffix(line, `"""`)
|
||||
sb.WriteString(line)
|
||||
if !ok {
|
||||
// no multiline terminating string; need more input
|
||||
fmt.Fprintln(&sb)
|
||||
multiline = MultilinePrompt
|
||||
scanner.Prompt.UseAlt = true
|
||||
}
|
||||
case scanner.Pasting:
|
||||
fmt.Fprintln(&sb, line)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/list"):
|
||||
args := strings.Fields(line)
|
||||
if err := ListHandler(cmd, args[1:]); err != nil {
|
||||
return err
|
||||
}
|
||||
case strings.HasPrefix(line, "/load"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage:\n /load <modelname>")
|
||||
continue
|
||||
}
|
||||
opts.Model = args[1]
|
||||
opts.Messages = []api.Message{}
|
||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||
if err := loadModel(cmd, &opts); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/save"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage:\n /save <modelname>")
|
||||
continue
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
|
||||
req := &api.CreateRequest{
|
||||
Name: args[1],
|
||||
Modelfile: buildModelfile(opts),
|
||||
}
|
||||
fn := func(resp api.ProgressResponse) error { return nil }
|
||||
err = client.Create(cmd.Context(), req, fn)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't save model")
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Created new model '%s'\n", args[1])
|
||||
continue
|
||||
case strings.HasPrefix(line, "/set"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
switch args[1] {
|
||||
case "history":
|
||||
scanner.HistoryEnable()
|
||||
case "nohistory":
|
||||
scanner.HistoryDisable()
|
||||
case "wordwrap":
|
||||
opts.WordWrap = true
|
||||
fmt.Println("Set 'wordwrap' mode.")
|
||||
case "nowordwrap":
|
||||
opts.WordWrap = false
|
||||
fmt.Println("Set 'nowordwrap' mode.")
|
||||
case "verbose":
|
||||
cmd.Flags().Set("verbose", "true")
|
||||
fmt.Println("Set 'verbose' mode.")
|
||||
case "quiet":
|
||||
cmd.Flags().Set("verbose", "false")
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||
} else {
|
||||
opts.Format = args[2]
|
||||
fmt.Printf("Set format to '%s' mode.\n", args[2])
|
||||
}
|
||||
case "noformat":
|
||||
opts.Format = ""
|
||||
fmt.Println("Disabled format.")
|
||||
case "parameter":
|
||||
if len(args) < 4 {
|
||||
usageParameters()
|
||||
continue
|
||||
}
|
||||
params := args[3:]
|
||||
fp, err := api.FormatParams(map[string][]string{args[2]: params})
|
||||
if err != nil {
|
||||
fmt.Printf("Couldn't set parameter: %q\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
|
||||
opts.Options[args[2]] = fp[args[2]]
|
||||
case "system", "template":
|
||||
if len(args) < 3 {
|
||||
usageSet()
|
||||
continue
|
||||
}
|
||||
|
||||
if args[1] == "system" {
|
||||
multiline = MultilineSystem
|
||||
} else if args[1] == "template" {
|
||||
multiline = MultilineTemplate
|
||||
}
|
||||
|
||||
line := strings.Join(args[2:], " ")
|
||||
line, ok := strings.CutPrefix(line, `"""`)
|
||||
if !ok {
|
||||
multiline = MultilineNone
|
||||
} else {
|
||||
// only cut suffix if the line is multiline
|
||||
line, ok = strings.CutSuffix(line, `"""`)
|
||||
if ok {
|
||||
multiline = MultilineNone
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(line)
|
||||
if multiline != MultilineNone {
|
||||
scanner.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
|
||||
if args[1] == "system" {
|
||||
opts.System = sb.String()
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
} else if args[1] == "template" {
|
||||
opts.Template = sb.String()
|
||||
fmt.Println("Set prompt template.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
continue
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
usageSet()
|
||||
}
|
||||
case strings.HasPrefix(line, "/show"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: opts.Model,
|
||||
System: opts.System,
|
||||
Template: opts.Template,
|
||||
Options: opts.Options,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model")
|
||||
return err
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "info":
|
||||
fmt.Println("Model details:")
|
||||
if len(resp.Details.Families) > 0 {
|
||||
fmt.Printf("Family %s\n", strings.Join(resp.Details.Families, ", "))
|
||||
} else if resp.Details.Family != "" {
|
||||
fmt.Printf("Family %s\n", resp.Details.Family)
|
||||
}
|
||||
fmt.Printf("Parameter Size %s\n", resp.Details.ParameterSize)
|
||||
fmt.Printf("Quantization Level %s\n", resp.Details.QuantizationLevel)
|
||||
fmt.Println("")
|
||||
case "license":
|
||||
if resp.License == "" {
|
||||
fmt.Println("No license was specified for this model.")
|
||||
} else {
|
||||
fmt.Println(resp.License)
|
||||
}
|
||||
case "modelfile":
|
||||
fmt.Println(resp.Modelfile)
|
||||
case "parameters":
|
||||
if resp.Parameters == "" {
|
||||
fmt.Println("No parameters were specified for this model.")
|
||||
} else {
|
||||
if len(opts.Options) > 0 {
|
||||
fmt.Println("User defined parameters:")
|
||||
for k, v := range opts.Options {
|
||||
fmt.Printf("%-*s %v\n", 30, k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Println("Model defined parameters:")
|
||||
fmt.Println(resp.Parameters)
|
||||
}
|
||||
case "system":
|
||||
switch {
|
||||
case opts.System != "":
|
||||
fmt.Println(opts.System + "\n")
|
||||
case resp.System != "":
|
||||
fmt.Println(resp.System + "\n")
|
||||
default:
|
||||
fmt.Println("No system message was specified for this model.")
|
||||
}
|
||||
case "template":
|
||||
switch {
|
||||
case opts.Template != "":
|
||||
fmt.Println(opts.Template + "\n")
|
||||
case resp.Template != "":
|
||||
fmt.Println(resp.Template)
|
||||
default:
|
||||
fmt.Println("No prompt template was specified for this model.")
|
||||
}
|
||||
default:
|
||||
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
usageShow()
|
||||
}
|
||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
switch args[1] {
|
||||
case "set", "/set":
|
||||
usageSet()
|
||||
case "show", "/show":
|
||||
usageShow()
|
||||
case "shortcut", "shortcuts":
|
||||
usageShortcuts()
|
||||
}
|
||||
} else {
|
||||
usage()
|
||||
}
|
||||
case line == "/exit", line == "/bye":
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/"):
|
||||
args := strings.Fields(line)
|
||||
isFile := false
|
||||
|
||||
if opts.MultiModal {
|
||||
for _, f := range extractFileNames(line) {
|
||||
if strings.HasPrefix(f, args[0]) {
|
||||
isFile = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isFile {
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
|
||||
continue
|
||||
}
|
||||
|
||||
sb.WriteString(line)
|
||||
default:
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
if sb.Len() > 0 && multiline == MultilineNone {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
|
||||
if opts.MultiModal {
|
||||
msg, images, err := extractFileData(sb.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// clear all previous images for better responses
|
||||
if len(images) > 0 {
|
||||
for i := range opts.Messages {
|
||||
opts.Messages[i].Images = nil
|
||||
}
|
||||
}
|
||||
|
||||
newMessage.Content = msg
|
||||
newMessage.Images = images
|
||||
}
|
||||
|
||||
opts.Messages = append(opts.Messages, newMessage)
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
opts.Messages = append(opts.Messages, *assistant)
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildModelfile(opts runOptions) string {
|
||||
var mf strings.Builder
|
||||
model := opts.ParentModel
|
||||
if model == "" {
|
||||
model = opts.Model
|
||||
}
|
||||
fmt.Fprintf(&mf, "FROM %s\n", model)
|
||||
if opts.System != "" {
|
||||
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
|
||||
}
|
||||
|
||||
if opts.Template != "" {
|
||||
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
|
||||
}
|
||||
|
||||
keys := make([]string, 0)
|
||||
for k := range opts.Options {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k])
|
||||
}
|
||||
fmt.Fprintln(&mf)
|
||||
|
||||
for _, msg := range opts.Messages {
|
||||
fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content)
|
||||
}
|
||||
|
||||
return mf.String()
|
||||
}
|
||||
|
||||
func normalizeFilePath(fp string) string {
|
||||
// Define a map of escaped characters and their replacements
|
||||
replacements := map[string]string{
|
||||
"\\ ": " ", // Escaped space
|
||||
"\\(": "(", // Escaped left parenthesis
|
||||
"\\)": ")", // Escaped right parenthesis
|
||||
"\\[": "[", // Escaped left square bracket
|
||||
"\\]": "]", // Escaped right square bracket
|
||||
"\\{": "{", // Escaped left curly brace
|
||||
"\\}": "}", // Escaped right curly brace
|
||||
"\\$": "$", // Escaped dollar sign
|
||||
"\\&": "&", // Escaped ampersand
|
||||
"\\;": ";", // Escaped semicolon
|
||||
"\\'": "'", // Escaped single quote
|
||||
"\\\\": "\\", // Escaped backslash
|
||||
"\\*": "*", // Escaped asterisk
|
||||
"\\?": "?", // Escaped question mark
|
||||
}
|
||||
|
||||
for escaped, actual := range replacements {
|
||||
fp = strings.ReplaceAll(fp, escaped, actual)
|
||||
}
|
||||
return fp
|
||||
}
|
||||
|
||||
func extractFileNames(input string) []string {
|
||||
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
|
||||
// and followed by more characters and a file extension
|
||||
// This will capture non filename strings, but we'll check for file existence to remove mismatches
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
|
||||
return re.FindAllString(input, -1)
|
||||
}
|
||||
|
||||
func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
filePaths := extractFileNames(input)
|
||||
var imgs []api.ImageData
|
||||
|
||||
for _, fp := range filePaths {
|
||||
nfp := normalizeFilePath(fp)
|
||||
data, err := getImageData(nfp)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
|
||||
return "", imgs, err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
return input, imgs, nil
|
||||
}
|
||||
|
||||
func getImageData(filePath string) ([]byte, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := make([]byte, 512)
|
||||
_, err = file.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
}
|
||||
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the file size exceeds 100MB
|
||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
|
||||
if info.Size() > maxSize {
|
||||
return nil, fmt.Errorf("file size exceeds maximum limit (100MB)")
|
||||
}
|
||||
|
||||
buf = make([]byte, info.Size())
|
||||
_, err = file.Seek(0, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.ReadFull(file, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
116 cmd/interactive_test.go Normal file
@@ -0,0 +1,116 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
func TestExtractFilenames(t *testing.T) {
|
||||
// Unix style paths
|
||||
input := ` some preamble
|
||||
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2
|
||||
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.svg`
|
||||
res := extractFileNames(input)
|
||||
assert.Len(t, res, 5)
|
||||
assert.Contains(t, res[0], "one.png")
|
||||
assert.Contains(t, res[1], "two.jpg")
|
||||
assert.Contains(t, res[2], "three.jpeg")
|
||||
assert.Contains(t, res[3], "four.png")
|
||||
assert.Contains(t, res[4], "five.svg")
|
||||
assert.NotContains(t, res[4], '"')
|
||||
assert.NotContains(t, res, "inbtween")
|
||||
|
||||
// Windows style paths
|
||||
input = ` some preamble
|
||||
c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
|
||||
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
|
||||
./relative\ path/five.svg inbetween5 "./relative with/spaces/six.png inbetween6
|
||||
d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
|
||||
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.svg some ending
|
||||
`
|
||||
res = extractFileNames(input)
|
||||
assert.Len(t, res, 10)
|
||||
assert.NotContains(t, res, "inbtween")
|
||||
assert.Contains(t, res[0], "one.png")
|
||||
assert.Contains(t, res[0], "c:")
|
||||
assert.Contains(t, res[1], "two.jpg")
|
||||
assert.Contains(t, res[1], "c:")
|
||||
assert.Contains(t, res[2], "three.jpeg")
|
||||
assert.Contains(t, res[3], "four.png")
|
||||
assert.Contains(t, res[4], "five.svg")
|
||||
assert.Contains(t, res[5], "six.png")
|
||||
assert.Contains(t, res[6], "seven.svg")
|
||||
assert.Contains(t, res[6], "d:")
|
||||
assert.Contains(t, res[7], "eight.png")
|
||||
assert.Contains(t, res[7], "c:")
|
||||
assert.Contains(t, res[8], "nine.png")
|
||||
assert.Contains(t, res[8], "d:")
|
||||
assert.Contains(t, res[9], "ten.svg")
|
||||
assert.Contains(t, res[9], "E:")
|
||||
}
|
||||
|
||||
func TestModelfileBuilder(t *testing.T) {
|
||||
opts := runOptions{
|
||||
Model: "hork",
|
||||
System: "You are part horse and part shark, but all hork. Do horklike things",
|
||||
Template: "This is a template.",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hey there hork!"},
|
||||
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
|
||||
},
|
||||
Options: map[string]interface{}{},
|
||||
}
|
||||
|
||||
opts.Options["temperature"] = 0.9
|
||||
opts.Options["seed"] = 42
|
||||
opts.Options["penalize_newline"] = false
|
||||
opts.Options["stop"] = []string{"hi", "there"}
|
||||
|
||||
mf := buildModelfile(opts)
|
||||
expectedModelfile := `FROM {{.Model}}
|
||||
SYSTEM """{{.System}}"""
|
||||
TEMPLATE """{{.Template}}"""
|
||||
PARAMETER penalize_newline false
|
||||
PARAMETER seed 42
|
||||
PARAMETER stop [hi there]
|
||||
PARAMETER temperature 0.9
|
||||
|
||||
MESSAGE user """Hey there hork!"""
|
||||
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
|
||||
`
|
||||
|
||||
tmpl, err := template.New("").Parse(expectedModelfile)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, opts)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, buf.String(), mf)
|
||||
|
||||
opts.ParentModel = "horseshark"
|
||||
mf = buildModelfile(opts)
|
||||
expectedModelfile = `FROM {{.ParentModel}}
|
||||
SYSTEM """{{.System}}"""
|
||||
TEMPLATE """{{.Template}}"""
|
||||
PARAMETER penalize_newline false
|
||||
PARAMETER seed 42
|
||||
PARAMETER stop [hi there]
|
||||
PARAMETER temperature 0.9
|
||||
|
||||
MESSAGE user """Hey there hork!"""
|
||||
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
|
||||
`
|
||||
|
||||
tmpl, err = template.New("").Parse(expectedModelfile)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var parentBuf bytes.Buffer
|
||||
err = tmpl.Execute(&parentBuf, opts)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, parentBuf.String(), mf)
|
||||
}
|
@@ -12,7 +12,7 @@ Import models using source model weights found on Hugging Face and similar sites
|
||||
|
||||
Installing on Linux in most cases is easy using the script on Ollama.ai. To get more detail about the install, including CUDA drivers, see the **[Linux Documentation](./linux.md)**.
|
||||
|
||||
Many of our users like the flexibility of using our official Docker Image. Learn more about using Docker with Ollama using the **[Docker Documentation](./docker.md)**.
|
||||
Many of our users like the flexibility of using our official Docker Image. Learn more about using Docker with Ollama using the **[Docker Documentation](https://hub.docker.com/r/ollama/ollama)**.
|
||||
|
||||
It is easy to install on Linux and Mac, but many users will choose to build Ollama on their own. To do this, refer to the **[Development Documentation](./development.md)**.
|
||||
|
||||
|
@@ -409,7 +409,7 @@ A stream of JSON objects is returned:
|
||||
"model": "llama2",
|
||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||
"message": {
|
||||
"role": "assisant",
|
||||
"role": "assistant",
|
||||
"content": "The",
|
||||
"images": null
|
||||
},
|
||||
@@ -505,7 +505,7 @@ A stream of JSON objects is returned:
|
||||
"model": "llama2",
|
||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||
"message": {
|
||||
"role": "assisant",
|
||||
"role": "assistant",
|
||||
"content": "The"
|
||||
},
|
||||
"done": false
|
||||
|
@@ -1,13 +1,9 @@
|
||||
# Development
|
||||
|
||||
- Install cmake or (optionally, required tools for GPUs)
|
||||
- run `go generate ./...`
|
||||
- run `go build .`
|
||||
|
||||
Install required tools:
|
||||
|
||||
- cmake version 3.24 or higher
|
||||
- go version 1.20 or higher
|
||||
- go version 1.21 or higher
|
||||
- gcc version 11.4.0 or higher
|
||||
|
||||
```bash
|
||||
@@ -17,7 +13,11 @@ brew install go cmake gcc
|
||||
Optionally enable debugging and more verbose logging:
|
||||
|
||||
```bash
|
||||
# At build time
|
||||
export CGO_CFLAGS="-g"
|
||||
|
||||
# At runtime
|
||||
export OLLAMA_DEBUG=1
|
||||
```
|
||||
|
||||
Get the required libraries and build the native LLM code:
|
||||
@@ -38,37 +38,101 @@ Now you can run `ollama`:
|
||||
./ollama
|
||||
```
|
||||
|
||||
## Building on Linux with GPU support
|
||||
### Linux
|
||||
|
||||
#### Linux CUDA (NVIDIA)
|
||||
|
||||
### Linux/Windows CUDA (NVIDIA)
|
||||
*Your operating system distribution may already have packages for NVIDIA CUDA. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!*
|
||||
|
||||
Install `cmake` and `golang` as well as [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) development and runtime packages.
|
||||
Install `cmake` and `golang` as well as [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads)
|
||||
development and runtime packages.
|
||||
|
||||
Typically the build scripts will auto-detect CUDA. However, if your Linux distro
or installation approach uses unusual paths, you can specify the location of the
shared libraries with the `CUDA_LIB_DIR` environment variable and the location of
the nvcc compiler with `CUDACXX`. You can customize the set of target CUDA
architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70").
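For example, with a CUDA toolkit installed under a non-standard prefix, the overrides might be supplied at generate time like this (the paths and architectures shown are illustrative, not defaults):

```
CUDA_LIB_DIR=/opt/cuda/lib64 CUDACXX=/opt/cuda/bin/nvcc CMAKE_CUDA_ARCHITECTURES="50;60;70" go generate ./...
```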
|
||||
|
||||
Then generate dependencies:
|
||||
|
||||
```
|
||||
go generate ./...
|
||||
```
|
||||
|
||||
Then build the binary:
|
||||
|
||||
```
|
||||
go build .
|
||||
```
|
||||
|
||||
### Linux ROCm (AMD)
|
||||
#### Linux ROCm (AMD)
|
||||
|
||||
*Your operating system distribution may already have packages for AMD ROCm and CLBlast. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!*
|
||||
|
||||
Install [CLBlast](https://github.com/CNugteren/CLBlast/blob/master/doc/installation.md) and [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html) development packages first, as well as `cmake` and `golang`.
|
||||
Adjust the paths below (correct for Arch) as appropriate for your distributions install locations and generate dependencies:
|
||||
|
||||
Typically the build scripts will auto-detect ROCm. However, if your Linux distro
or installation approach uses unusual paths, you can specify the location of the
ROCm install (typically `/opt/rocm`) with the `ROCM_PATH` environment variable and
the location of the CLBlast install (typically `/usr/lib/cmake/CLBlast`) with
`CLBlast_DIR`. You can also customize the AMD GPU targets by setting
`AMDGPU_TARGETS` (e.g. `AMDGPU_TARGETS="gfx1101;gfx1102"`).
|
||||
|
||||
```
|
||||
CLBlast_DIR=/usr/lib/cmake/CLBlast ROCM_PATH=/opt/rocm go generate ./...
|
||||
go generate ./...
|
||||
```
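Similarly, to restrict compilation to specific AMD GPU targets as described above, the override can be passed the same way (the targets shown are only an example):

```
AMDGPU_TARGETS="gfx1101;gfx1102" go generate ./...
```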
|
||||
|
||||
Then build the binary:
|
||||
|
||||
```
|
||||
go build .
|
||||
```
|
||||
|
||||
ROCm requires elevated privileges to access the GPU at runtime. On most distros you can add your user account to the `render` group, or run as root.
|
||||
|
||||
## Containerized Build
|
||||
#### Advanced CPU Settings
|
||||
|
||||
If you have Docker available, you can build linux binaries with `./scripts/build_linux.sh` which has the CUDA and ROCm dependencies included.
|
||||
By default, running `go generate ./...` will compile a few different variations
|
||||
of the LLM library based on common CPU families and vector math capabilities,
|
||||
including a lowest-common-denominator which should run on almost any 64 bit CPU
|
||||
somewhat slowly. At runtime, Ollama will auto-detect the optimal variation to
|
||||
load. If you would like a CPU build customized for your
|
||||
processor, you can set `OLLAMA_CUSTOM_CPU_DEFS` to the llama.cpp flags you would
|
||||
like to use. For example, to compile an optimized binary for an Intel i9-9880H,
|
||||
you might use:
|
||||
|
||||
```
|
||||
OLLAMA_CUSTOM_CPU_DEFS="-DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_F16C=on -DLLAMA_FMA=on" go generate ./...
|
||||
go build .
|
||||
```
|
||||
|
||||
#### Containerized Linux Build
|
||||
|
||||
If you have Docker available, you can build linux binaries with `./scripts/build_linux.sh` which has the CUDA and ROCm dependencies included. The resulting binary is placed in `./dist`
|
||||
|
||||
|
||||
### Windows
|
||||
|
||||
Note: The windows build for Ollama is still under development.
|
||||
|
||||
Install required tools:
|
||||
|
||||
- MSVC toolchain - C/C++ and cmake as minimal requirements
|
||||
- go version 1.21 or higher
|
||||
- MinGW (pick one variant) with GCC.
|
||||
- <https://www.mingw-w64.org/>
|
||||
- <https://www.msys2.org/>
|
||||
|
||||
```powershell
|
||||
$env:CGO_ENABLED="1"
|
||||
|
||||
go generate ./...
|
||||
|
||||
go build .
|
||||
```
|
||||
|
||||
#### Windows CUDA (NVIDIA)
|
||||
|
||||
In addition to the common Windows development tools described above, install:
|
||||
|
||||
- [NVIDIA CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html)
|
||||
|
59 docs/faq.md
@@ -8,35 +8,38 @@ To upgrade Ollama, run the installation process again. On the Mac, click the Oll
|
||||
|
||||
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
|
||||
|
||||
## How do I use Ollama server environment variables on Mac
|
||||
## How do I configure Ollama server?
|
||||
|
||||
On macOS, Ollama runs in the background and is managed by the menubar app. If adding environment variables, Ollama will need to be run manually.
|
||||
Ollama server can be configured with environment variables.
|
||||
|
||||
1. Click the menubar icon for Ollama and choose **Quit Ollama**.
|
||||
2. Open a new terminal window and run the following command (this example uses `OLLAMA_HOST` with an IP address of `123.1.1.1`):
|
||||
### Setting environment variables on Mac
|
||||
|
||||
```bash
|
||||
OLLAMA_HOST=123.1.1.1 ollama serve
|
||||
```
|
||||
If Ollama is run as a macOS application, environment variables should be set using `launchctl`:
|
||||
|
||||
## How do I use Ollama server environment variables on Linux?
|
||||
1. For each environment variable, call `launchctl setenv`.
|
||||
|
||||
If Ollama is installed with the install script, a systemd service was created, running as the Ollama user. To add an environment variable, such as OLLAMA_HOST, follow these steps:
|
||||
```bash
|
||||
launchctl setenv OLLAMA_HOST "0.0.0.0"
|
||||
```
|
||||
|
||||
1. Create a `systemd` drop-in directory and add a config file. This is only needed once.
|
||||
2. Restart the Ollama application.
|
||||
|
||||
```bash
|
||||
mkdir -p /etc/systemd/system/ollama.service.d
|
||||
echo '[Service]' >>/etc/systemd/system/ollama.service.d/environment.conf
|
||||
```
|
||||
### Setting environment variables on Linux
|
||||
|
||||
2. For each environment variable, add it to the config file:
|
||||
If Ollama is run as a systemd service, environment variables should be set using `systemctl`:
|
||||
|
||||
```bash
|
||||
echo 'Environment="OLLAMA_HOST=0.0.0.0:11434"' >>/etc/systemd/system/ollama.service.d/environment.conf
|
||||
```
|
||||
1. Edit the systemd service by calling `systemctl edit ollama.service`. This will open an editor.
|
||||
|
||||
3. Reload `systemd` and restart Ollama:
|
||||
2. For each environment variable, add an `Environment` line under the `[Service]` section:
|
||||
|
||||
```ini
|
||||
[Service]
|
||||
Environment="OLLAMA_HOST=0.0.0.0"
|
||||
```
|
||||
|
||||
3. Save and exit.
|
||||
|
||||
4. Reload `systemd` and restart Ollama:
|
||||
|
||||
```bash
|
||||
systemctl daemon-reload
|
||||
@@ -45,26 +48,26 @@ If Ollama is installed with the install script, a systemd service was created, r
|
||||
|
||||
## How can I expose Ollama on my network?
|
||||
|
||||
Ollama binds to 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable. Refer to the section above for how to use environment variables on your platform.
|
||||
Ollama binds to 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable.
|
||||
|
||||
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
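For example, when running the server manually in a terminal, binding to all interfaces could look like this (a sketch; use the platform-specific steps above for the app or systemd installs):

```shell
OLLAMA_HOST=0.0.0.0:11434 ollama serve
```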
|
||||
|
||||
## How can I allow additional web origins to access Ollama?
|
||||
|
||||
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Add additional origins with the `OLLAMA_ORIGINS` environment variable. For example, to add all ports on 192.168.1.1 and https://example.com, use:
|
||||
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
|
||||
|
||||
```shell
|
||||
OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com
|
||||
```
|
||||
|
||||
Refer to the section above for how to use environment variables on your platform.
|
||||
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
|
||||
|
||||
## Where are models stored?
|
||||
|
||||
- macOS: `~/.ollama/models`.
|
||||
- Linux: `/usr/share/ollama/.ollama/models`
|
||||
|
||||
## How do I set them to a different location?
|
||||
### How do I set them to a different location?
|
||||
|
||||
If a different directory needs to be used, set the environment variable `OLLAMA_MODELS` to the chosen directory. Refer to the section above for how to use environment variables on your platform.
|
||||
If a different directory needs to be used, set the environment variable `OLLAMA_MODELS` to the chosen directory.
|
||||
|
||||
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
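For example, on Linux with the systemd service, the variable could be added via `systemctl edit ollama.service` as described above (the directory shown is only an illustration):

```ini
[Service]
Environment="OLLAMA_MODELS=/data/ollama/models"
```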
|
||||
|
||||
## Does Ollama send my prompts and answers back to Ollama.ai to use in any way?
|
||||
|
||||
|
@@ -109,8 +109,9 @@ Remove the ollama binary from your bin directory (either `/usr/local/bin`, `/usr
|
||||
sudo rm $(which ollama)
|
||||
```
|
||||
|
||||
Remove the downloaded models and Ollama service user:
|
||||
Remove the downloaded models and Ollama service user and group:
|
||||
```bash
|
||||
sudo rm -r /usr/share/ollama
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
```
|
||||
|
@@ -19,6 +19,7 @@ A model file is the blueprint to create and share models with Ollama.
|
||||
- [SYSTEM](#system)
|
||||
- [ADAPTER](#adapter)
|
||||
- [LICENSE](#license)
|
||||
- [MESSAGE](#message)
|
||||
- [Notes](#notes)
|
||||
|
||||
## Format
|
||||
@@ -38,6 +39,7 @@ INSTRUCTION arguments
|
||||
| [`SYSTEM`](#system) | Specifies the system message that will be set in the template. |
|
||||
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
|
||||
| [`LICENSE`](#license) | Specifies the legal license. |
|
||||
| [`MESSAGE`](#message) | Specify message history. |
|
||||
|
||||
## Examples
|
||||
|
||||
@@ -156,11 +158,12 @@ PARAMETER <parameter> <parametervalue>
|
||||
|
||||
#### Template Variables
|
||||
|
||||
| Variable | Description |
|
||||
| --------------- | ------------------------------------------------------------------------------------------------------------- |
|
||||
| `{{ .System }}` | The system message used to specify custom behavior, this must also be set in the Modelfile as an instruction. |
|
||||
| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. |
|
||||
| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |
|
||||
| Variable | Description |
|
||||
| ----------------- | ------------------------------------------------------------------------------------------------------------- |
|
||||
| `{{ .System }}` | The system message used to specify custom behavior, this must also be set in the Modelfile as an instruction. |
|
||||
| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. |
|
||||
| `{{ .Response }}` | The response from the LLM, if not specified response is appended to the end of the template. |
|
||||
| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |
|
||||
|
||||
```modelfile
|
||||
TEMPLATE """
|
||||
@@ -204,6 +207,19 @@ LICENSE """
|
||||
"""
|
||||
```
|
||||
|
||||
### MESSAGE
|
||||
|
||||
The `MESSAGE` instruction allows you to specify a message history for the model to use when responding:
|
||||
|
||||
```modelfile
|
||||
MESSAGE user Is Toronto in Canada?
|
||||
MESSAGE assistant yes
|
||||
MESSAGE user Is Sacramento in Canada?
|
||||
MESSAGE assistant no
|
||||
MESSAGE user Is Ontario in Canada?
|
||||
MESSAGE assistant yes
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.
|
||||
|
@@ -12,11 +12,49 @@ On Linux systems with systemd, the logs can be found with this command:
journalctl -u ollama
```

When you run Ollama in a container, the logs go to stdout/stderr in the container:

```shell
docker logs <container-name>
```

(Use `docker ps` to find the container name.)

If you run `ollama serve` manually in a terminal, the logs will be on that terminal.

Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.

## LLM libraries

Ollama includes multiple LLM libraries compiled for different GPUs and CPU
vector features. Ollama tries to pick the best one based on the capabilities of
your system. If this autodetection has problems, or you run into other problems
(e.g. crashes in your GPU), you can work around this by forcing a specific LLM
library. `cpu_avx2` will perform the best, followed by `cpu_avx`; the slowest
but most compatible is `cpu`. Rosetta emulation under macOS will work with the
`cpu` library.

In the server log, you will see a message that looks something like this (it varies
from release to release):

```
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
```

**Experimental LLM Library Override**

You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass
autodetection. For example, if you have a CUDA card but want to force the
CPU LLM library with AVX2 vector support, use:

```
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
```

You can see what features your CPU has with the following command:

```
cat /proc/cpuinfo | grep flags | head -1
```
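
If you prefer to check this from code, here is a minimal sketch (an illustration only, not part of Ollama's CLI) using the same `golang.org/x/sys/cpu` package the project uses for autodetection; the flag-to-library mapping simply mirrors the list above:

```go
package main

import (
	"fmt"

	"golang.org/x/sys/cpu"
)

func main() {
	// Prefer AVX2, then AVX, otherwise fall back to the plain "cpu" library,
	// mirroring the autodetection order described above.
	switch {
	case cpu.X86.HasAVX2:
		fmt.Println("best CPU library: cpu_avx2")
	case cpu.X86.HasAVX:
		fmt.Println("best CPU library: cpu_avx")
	default:
		fmt.Println("best CPU library: cpu")
	}
}
```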

## Known issues

* `signal: illegal instruction (core dumped)`: Ollama requires AVX support from the CPU. This was introduced in 2011, and CPUs started offering it in 2012. CPUs from before that, and some lower-end CPUs after that, may not have AVX support and thus are not supported by Ollama. Some users have had luck building Ollama on their machines, disabling the need for AVX.
* N/A
go.mod
@@ -1,6 +1,6 @@
module github.com/jmorganca/ollama

go 1.20
go 1.21

require (
	github.com/emirpasic/gods v1.18.1
@@ -45,7 +45,7 @@ require (
	golang.org/x/crypto v0.14.0
	golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
	golang.org/x/net v0.17.0 // indirect
	golang.org/x/sys v0.13.0 // indirect
	golang.org/x/sys v0.13.0
	golang.org/x/term v0.13.0
	golang.org/x/text v0.13.0 // indirect
	google.golang.org/protobuf v1.30.0 // indirect
gpu/cpu_common.go (new file)
@@ -0,0 +1,21 @@
package gpu

import (
	"log/slog"

	"golang.org/x/sys/cpu"
)

func GetCPUVariant() string {
	if cpu.X86.HasAVX2 {
		slog.Info("CPU has AVX2")
		return "avx2"
	}
	if cpu.X86.HasAVX {
		slog.Info("CPU has AVX")
		return "avx"
	}
	slog.Info("CPU does not have vector extensions")
	// else LCD (lowest common denominator): no vector extensions detected
	return ""
}
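
A hedged sketch of how a caller might turn this variant into one of the dynamic library names listed in the troubleshooting doc; the exact naming scheme is an assumption for illustration, not something spelled out in this file:

```go
// Hypothetical helper for illustration only: combine the base CPU library
// name with the detected variant, e.g. "cpu_avx2", "cpu_avx", or plain "cpu".
func cpuLibraryName() string {
	if variant := GetCPUVariant(); variant != "" {
		return "cpu_" + variant
	}
	return "cpu"
}
```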
|
gpu/gpu.go
@@ -12,12 +12,14 @@ package gpu
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
type handles struct {
|
||||
@@ -28,31 +30,86 @@ type handles struct {
|
||||
var gpuMutex sync.Mutex
|
||||
var gpuHandles *handles = nil
|
||||
|
||||
// With our current CUDA compile flags, older than 5.0 will not work properly
|
||||
var CudaComputeMin = [2]C.int{5, 0}
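// For example, a Kepler card reporting compute capability 3.5 fails this
// minimum and Ollama falls back to CPU, while Pascal (6.1) and newer pass.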
|
||||
|
||||
// Possible locations for the nvidia-ml library
|
||||
var CudaLinuxGlobs = []string{
|
||||
"/usr/local/cuda/lib64/libnvidia-ml.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
|
||||
"/usr/lib/wsl/lib/libnvidia-ml.so*",
|
||||
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
|
||||
"/opt/cuda/lib64/libnvidia-ml.so*",
|
||||
"/usr/lib*/libnvidia-ml.so*",
|
||||
"/usr/local/lib*/libnvidia-ml.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
|
||||
|
||||
// TODO: are these stubs ever valid?
|
||||
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
|
||||
}
|
||||
|
||||
var CudaWindowsGlobs = []string{
|
||||
"c:\\Windows\\System32\\nvml.dll",
|
||||
}
|
||||
|
||||
var RocmLinuxGlobs = []string{
|
||||
"/opt/rocm*/lib*/librocm_smi64.so*",
|
||||
}
|
||||
|
||||
var RocmWindowsGlobs = []string{
|
||||
"c:\\Windows\\System32\\rocm_smi64.dll",
|
||||
}
|
||||
|
||||
// Note: gpuMutex must already be held
|
||||
func initGPUHandles() {
|
||||
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
||||
log.Printf("Detecting GPU type")
|
||||
gpuHandles = &handles{nil, nil}
|
||||
var resp C.cuda_init_resp_t
|
||||
C.cuda_init(&resp)
|
||||
if resp.err != nil {
|
||||
log.Printf("CUDA not detected: %s", C.GoString(resp.err))
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
|
||||
var resp C.rocm_init_resp_t
|
||||
C.rocm_init(&resp)
|
||||
if resp.err != nil {
|
||||
log.Printf("ROCm not detected: %s", C.GoString(resp.err))
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
log.Printf("Radeon GPU detected")
|
||||
rocm := resp.rh
|
||||
gpuHandles.rocm = &rocm
|
||||
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
||||
|
||||
gpuHandles = &handles{nil, nil}
|
||||
var cudaMgmtName string
|
||||
var cudaMgmtPatterns []string
|
||||
var rocmMgmtName string
|
||||
var rocmMgmtPatterns []string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
cudaMgmtName = "nvml.dll"
|
||||
cudaMgmtPatterns = make([]string, len(CudaWindowsGlobs))
|
||||
copy(cudaMgmtPatterns, CudaWindowsGlobs)
|
||||
rocmMgmtName = "rocm_smi64.dll"
|
||||
rocmMgmtPatterns = make([]string, len(RocmWindowsGlobs))
|
||||
copy(rocmMgmtPatterns, RocmWindowsGlobs)
|
||||
case "linux":
|
||||
cudaMgmtName = "libnvidia-ml.so"
|
||||
cudaMgmtPatterns = make([]string, len(CudaLinuxGlobs))
|
||||
copy(cudaMgmtPatterns, CudaLinuxGlobs)
|
||||
rocmMgmtName = "librocm_smi64.so"
|
||||
rocmMgmtPatterns = make([]string, len(RocmLinuxGlobs))
|
||||
copy(rocmMgmtPatterns, RocmLinuxGlobs)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("Detecting GPU type")
|
||||
cudaLibPaths := FindGPULibs(cudaMgmtName, cudaMgmtPatterns)
|
||||
if len(cudaLibPaths) > 0 {
|
||||
cuda := LoadCUDAMgmt(cudaLibPaths)
|
||||
if cuda != nil {
|
||||
slog.Info("Nvidia GPU detected")
|
||||
gpuHandles.cuda = cuda
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
rocmLibPaths := FindGPULibs(rocmMgmtName, rocmMgmtPatterns)
|
||||
if len(rocmLibPaths) > 0 {
|
||||
rocm := LoadROCMMgmt(rocmLibPaths)
|
||||
if rocm != nil {
|
||||
slog.Info("Radeon GPU detected")
|
||||
gpuHandles.rocm = rocm
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Printf("Nvidia GPU detected")
|
||||
cuda := resp.ch
|
||||
gpuHandles.cuda = &cuda
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,39 +122,84 @@ func GetGPUInfo() GpuInfo {
|
||||
initGPUHandles()
|
||||
}
|
||||
|
||||
// All our GPU builds on x86 have AVX enabled, so fall back to CPU if we don't detect at least AVX
|
||||
cpuVariant := GetCPUVariant()
|
||||
if cpuVariant == "" && runtime.GOARCH == "amd64" {
|
||||
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
|
||||
}
|
||||
|
||||
var memInfo C.mem_info_t
|
||||
resp := GpuInfo{}
|
||||
if gpuHandles.cuda != nil {
|
||||
if gpuHandles.cuda != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
||||
C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
|
||||
if memInfo.err != nil {
|
||||
log.Printf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))
|
||||
slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
} else {
|
||||
resp.Library = "cuda"
|
||||
} else if memInfo.count > 0 {
|
||||
// Verify minimum compute capability
|
||||
var cc C.cuda_compute_capability_t
|
||||
C.cuda_compute_capability(*gpuHandles.cuda, &cc)
|
||||
if cc.err != nil {
|
||||
slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
|
||||
C.free(unsafe.Pointer(cc.err))
|
||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
||||
slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
resp.Library = "cuda"
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
}
|
||||
}
|
||||
} else if gpuHandles.rocm != nil {
|
||||
} else if gpuHandles.rocm != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
||||
C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
|
||||
if memInfo.err != nil {
|
||||
log.Printf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))
|
||||
slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
} else {
|
||||
} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
|
||||
// Only one GPU detected and it appears to be an integrated GPU - skip it
|
||||
slog.Info("ROCm unsupported integrated GPU detected")
|
||||
} else if memInfo.count > 0 {
|
||||
if memInfo.igpu_index >= 0 {
|
||||
// We have multiple GPUs reported, and one of them is an integrated GPU
|
||||
// so we have to set the env var to bypass it
|
||||
// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
|
||||
val := os.Getenv("ROCR_VISIBLE_DEVICES")
|
||||
if val == "" {
|
||||
devices := []string{}
|
||||
for i := 0; i < int(memInfo.count); i++ {
|
||||
if i == int(memInfo.igpu_index) {
|
||||
continue
|
||||
}
|
||||
devices = append(devices, strconv.Itoa(i))
|
||||
}
|
||||
val = strings.Join(devices, ",")
|
||||
os.Setenv("ROCR_VISIBLE_DEVICES", val)
|
||||
}
|
||||
slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
|
||||
}
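// Worked example: with three reported devices and the integrated GPU at
// index 1, the loop above yields ROCR_VISIBLE_DEVICES="0,2", so only the
// discrete GPUs remain visible to the runner.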
|
||||
resp.Library = "rocm"
|
||||
var version C.rocm_version_resp_t
|
||||
C.rocm_get_version(*gpuHandles.rocm, &version)
|
||||
verString := C.GoString(version.str)
|
||||
if version.status == 0 {
|
||||
resp.Variant = "v" + verString
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("failed to look up ROCm version: %s", verString))
|
||||
}
|
||||
C.free(unsafe.Pointer(version.str))
|
||||
}
|
||||
}
|
||||
if resp.Library == "" {
|
||||
C.cpu_check_ram(&memInfo)
|
||||
// In the future we may offer multiple CPU variants to tune CPU features
|
||||
if runtime.GOOS == "windows" {
|
||||
resp.Library = "cpu"
|
||||
} else {
|
||||
resp.Library = "default"
|
||||
}
|
||||
resp.Library = "cpu"
|
||||
resp.Variant = cpuVariant
|
||||
}
|
||||
if memInfo.err != nil {
|
||||
log.Printf("error looking up CPU memory: %s", C.GoString(memInfo.err))
|
||||
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
return resp
|
||||
}
|
||||
|
||||
resp.DeviceCount = uint32(memInfo.count)
|
||||
resp.FreeMemory = uint64(memInfo.free)
|
||||
resp.TotalMemory = uint64(memInfo.total)
|
||||
return resp
|
||||
@@ -119,31 +221,111 @@ func getCPUMem() (memInfo, error) {
|
||||
func CheckVRAM() (int64, error) {
|
||||
gpuInfo := GetGPUInfo()
|
||||
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
|
||||
return int64(gpuInfo.FreeMemory), nil
|
||||
// leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead
|
||||
overhead := gpuInfo.FreeMemory / 10
|
||||
gpus := uint64(gpuInfo.DeviceCount)
|
||||
if overhead < gpus*1024*1024*1024 {
|
||||
overhead = gpus * 1024 * 1024 * 1024
|
||||
}
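// Worked example: a single GPU with 20 GiB free -> 10% (2 GiB) exceeds the
// 1 GiB-per-GPU floor, so ~18 GiB is reported available; with only 6 GiB
// free, 10% (0.6 GiB) is under the floor, so a full 1 GiB is held back.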
|
||||
avail := int64(gpuInfo.FreeMemory - overhead)
|
||||
slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
|
||||
return avail, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determination
|
||||
}
|
||||
|
||||
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
|
||||
if opts.NumGPU != -1 {
|
||||
return opts.NumGPU
|
||||
func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
||||
var ldPaths []string
|
||||
gpuLibPaths := []string{}
|
||||
slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
ldPaths = strings.Split(os.Getenv("PATH"), ";")
|
||||
case "linux":
|
||||
ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
|
||||
default:
|
||||
return gpuLibPaths
|
||||
}
|
||||
info := GetGPUInfo()
|
||||
if info.Library == "cpu" || info.Library == "default" {
|
||||
return 0
|
||||
// Start with whatever we find in the PATH/LD_LIBRARY_PATH
|
||||
for _, ldPath := range ldPaths {
|
||||
d, err := filepath.Abs(ldPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
|
||||
}
|
||||
|
||||
/*
|
||||
Calculate bytes per layer, this will roughly be the size of the model file divided by the number of layers.
|
||||
We can store the model weights and the kv cache in vram,
|
||||
to enable kv cache vram storage, add two additional layers to the number of layers retrieved from the model file.
|
||||
*/
|
||||
bytesPerLayer := uint64(fileSizeBytes / numLayer)
|
||||
|
||||
// 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
|
||||
layers := int(info.FreeMemory/bytesPerLayer) * 3 / 4
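// Worked example (for the heuristic being replaced above): a 4 GiB model with
// 32 layers is roughly 128 MiB per layer; with 6 GiB of free VRAM that allows
// at most 48 layers, scaled to 75% -> up to 36 layers offloaded to the GPU.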
|
||||
|
||||
log.Printf("%d MB VRAM available, loading up to %d %s GPU layers out of %d", info.FreeMemory/(1024*1024), layers, info.Library, numLayer)
|
||||
|
||||
return layers
|
||||
slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
|
||||
for _, pattern := range patterns {
|
||||
// Ignore glob discovery errors
|
||||
matches, _ := filepath.Glob(pattern)
|
||||
for _, match := range matches {
|
||||
// Resolve any links so we don't try the same lib multiple times
|
||||
// and weed out any dups across globs
|
||||
libPath := match
|
||||
tmp := match
|
||||
var err error
|
||||
for ; err == nil; tmp, err = os.Readlink(libPath) {
|
||||
if !filepath.IsAbs(tmp) {
|
||||
tmp = filepath.Join(filepath.Dir(libPath), tmp)
|
||||
}
|
||||
libPath = tmp
|
||||
}
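// Example: /usr/lib/x86_64-linux-gnu/libnvidia-ml.so is usually a symlink
// chain (libnvidia-ml.so -> libnvidia-ml.so.1 -> libnvidia-ml.so.<driver>);
// resolving it here lets the duplicate check below record each real target
// only once even when several globs match different links to it.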
|
||||
new := true
|
||||
for _, cmp := range gpuLibPaths {
|
||||
if cmp == libPath {
|
||||
new = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if new {
|
||||
gpuLibPaths = append(gpuLibPaths, libPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
|
||||
return gpuLibPaths
|
||||
}
|
||||
|
||||
func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
|
||||
var resp C.cuda_init_resp_t
|
||||
resp.ch.verbose = getVerboseState()
|
||||
for _, libPath := range cudaLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
C.cuda_init(lib, &resp)
|
||||
if resp.err != nil {
|
||||
slog.Info(fmt.Sprintf("Unable to load CUDA management library %s: %s", libPath, C.GoString(resp.err)))
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
return &resp.ch
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
|
||||
var resp C.rocm_init_resp_t
|
||||
resp.rh.verbose = getVerboseState()
|
||||
for _, libPath := range rocmLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
C.rocm_init(lib, &resp)
|
||||
if resp.err != nil {
|
||||
slog.Info(fmt.Sprintf("Unable to load ROCm management library %s: %s", libPath, C.GoString(resp.err)))
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
return &resp.rh
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getVerboseState() C.uint16_t {
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
return C.uint16_t(1)
|
||||
}
|
||||
return C.uint16_t(0)
|
||||
}
|
||||
|
@@ -6,21 +6,41 @@ import "C"
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/pbnjay/memory"
|
||||
)
|
||||
|
||||
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
|
||||
func CheckVRAM() (int64, error) {
|
||||
// TODO - assume metal, and return free memory?
|
||||
return 0, nil
|
||||
if runtime.GOARCH == "amd64" {
|
||||
// gpu not supported, this may not be metal
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// on macOS, there's already a buffer for available vram (see below), so just return the total
|
||||
systemMemory := int64(memory.TotalMemory())
|
||||
|
||||
// macOS limits how much memory is available to the GPU based on the amount of system memory
|
||||
// TODO: handle case where iogpu.wired_limit_mb is set to a higher value
|
||||
if systemMemory <= 36*1024*1024*1024 {
|
||||
systemMemory = systemMemory * 2 / 3
|
||||
} else {
|
||||
systemMemory = systemMemory * 3 / 4
|
||||
}
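// Worked example: a 32 GiB machine advertises about 21 GiB (2/3) to the GPU,
// while a 64 GiB machine advertises 48 GiB (3/4).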
|
||||
|
||||
return systemMemory, nil
|
||||
}
|
||||
|
||||
func GetGPUInfo() GpuInfo {
|
||||
// TODO - Metal vs. x86 macs...
|
||||
mem, _ := getCPUMem()
|
||||
if runtime.GOARCH == "amd64" {
|
||||
return GpuInfo{
|
||||
Library: "cpu",
|
||||
Variant: GetCPUVariant(),
|
||||
memInfo: mem,
|
||||
}
|
||||
}
|
||||
return GpuInfo{
|
||||
Library: "default",
|
||||
Library: "metal",
|
||||
memInfo: mem,
|
||||
}
|
||||
}
|
||||
@@ -29,22 +49,6 @@ func getCPUMem() (memInfo, error) {
|
||||
return memInfo{
|
||||
TotalMemory: 0,
|
||||
FreeMemory: 0,
|
||||
DeviceCount: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
|
||||
if opts.NumGPU != -1 {
|
||||
return opts.NumGPU
|
||||
}
|
||||
|
||||
// metal only supported on arm64
|
||||
if runtime.GOARCH == "arm64" {
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func nativeInit() error {
|
||||
return nil
|
||||
}
|
||||
|
@@ -27,6 +27,13 @@
|
||||
|
||||
#endif
|
||||
|
||||
#define LOG(verbose, ...) \
|
||||
do { \
|
||||
if (verbose) { \
|
||||
fprintf(stderr, __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -34,6 +41,8 @@ extern "C" {
|
||||
typedef struct mem_info {
|
||||
uint64_t total;
|
||||
uint64_t free;
|
||||
unsigned int count;
|
||||
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
|
||||
char *err; // If non-null, caller responsible for freeing
|
||||
} mem_info_t;
|
||||
|
||||
|
@@ -8,6 +8,7 @@ void cpu_check_ram(mem_info_t *resp) {
|
||||
MEMORYSTATUSEX info;
|
||||
info.dwLength = sizeof(info);
|
||||
if (GlobalMemoryStatusEx(&info) != 0) {
|
||||
resp->count = 1;
|
||||
resp->total = info.ullTotalPhys;
|
||||
resp->free = info.ullAvailPhys;
|
||||
} else {
|
||||
@@ -26,6 +27,7 @@ void cpu_check_ram(mem_info_t *resp) {
|
||||
if (sysinfo(&info) != 0) {
|
||||
resp->err = strdup(strerror(errno));
|
||||
} else {
|
||||
resp->count = 1;
|
||||
resp->total = info.totalram * info.mem_unit;
|
||||
resp->free = info.freeram * info.mem_unit;
|
||||
}
|
||||
|
@@ -4,23 +4,7 @@
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#ifndef _WIN32
|
||||
const char *cuda_lib_paths[] = {
|
||||
"libnvidia-ml.so",
|
||||
"/usr/local/cuda/lib64/libnvidia-ml.so",
|
||||
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so",
|
||||
"/usr/lib/wsl/lib/libnvidia-ml.so.1", // TODO Maybe glob?
|
||||
NULL,
|
||||
};
|
||||
#else
|
||||
const char *cuda_lib_paths[] = {
|
||||
"nvml.dll",
|
||||
"",
|
||||
NULL,
|
||||
};
|
||||
#endif
|
||||
|
||||
void cuda_init(cuda_init_resp_t *resp) {
|
||||
void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
|
||||
nvmlReturn_t ret;
|
||||
resp->err = NULL;
|
||||
const int buflen = 256;
|
||||
@@ -30,34 +14,47 @@ void cuda_init(cuda_init_resp_t *resp) {
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[4] = {
|
||||
{"nvmlInit_v2", (void *)&resp->ch.initFn},
|
||||
{"nvmlShutdown", (void *)&resp->ch.shutdownFn},
|
||||
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.getHandle},
|
||||
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.getMemInfo},
|
||||
} l[] = {
|
||||
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
|
||||
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
|
||||
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
|
||||
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
|
||||
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
|
||||
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
|
||||
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
|
||||
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
|
||||
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
|
||||
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
|
||||
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
|
||||
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
for (i = 0; cuda_lib_paths[i] != NULL && resp->ch.handle == NULL; i++) {
|
||||
resp->ch.handle = LOAD_LIBRARY(cuda_lib_paths[i], RTLD_LAZY);
|
||||
}
|
||||
resp->ch.handle = LOAD_LIBRARY(cuda_lib_path, RTLD_LAZY);
|
||||
if (!resp->ch.handle) {
|
||||
// TODO improve error message, as the LOAD_ERR will typically have the
|
||||
// final path that was checked which might be confusing.
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "library %s load err: %s\n", cuda_lib_path, msg);
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Nvidia GPUs: %s",
|
||||
cuda_lib_paths[0], msg);
|
||||
cuda_lib_path, msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; i < 4; i++) { // TODO - fix this to use a null terminated list
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", cuda_lib_path);
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
||||
|
||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||
if (!l[i].p) {
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
||||
msg);
|
||||
free(msg);
|
||||
@@ -66,13 +63,23 @@ void cuda_init(cuda_init_resp_t *resp) {
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*resp->ch.initFn)();
|
||||
ret = (*resp->ch.nvmlInit_v2)();
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
return;
|
||||
// Report driver version if we're in verbose mode, ignore errors
|
||||
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
|
||||
@@ -89,22 +96,119 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO - handle multiple GPUs
|
||||
ret = (*h.getHandle)(0, &device);
|
||||
ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device handle: %d", ret);
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.getMemInfo)(device, &memInfo);
|
||||
resp->total = 0;
|
||||
resp->free = 0;
|
||||
for (i = 0; i < resp->count; i++) {
|
||||
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
if (h.verbose) {
|
||||
nvmlBrandType_t brand = 0;
|
||||
// When in verbose mode, report more information about
|
||||
// the card we discover, but don't fail on error
|
||||
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
|
||||
}
|
||||
}
|
||||
|
||||
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
|
||||
LOG(h.verbose, "[%d] CUDA usedMem %ld\n", i, memInfo.free);
|
||||
|
||||
resp->total += memInfo.total;
|
||||
resp->free += memInfo.free;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_compute_capability(cuda_handle_t h, cuda_compute_capability_t *resp) {
|
||||
resp->err = NULL;
|
||||
resp->major = 0;
|
||||
resp->minor = 0;
|
||||
nvmlDevice_t device;
|
||||
int major = 0;
|
||||
int minor = 0;
|
||||
nvmlReturn_t ret;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
if (h.handle == NULL) {
|
||||
resp->err = strdup("nvml handle not initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
unsigned int devices;
|
||||
ret = (*h.nvmlDeviceGetCount_v2)(&devices);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "device memory info lookup failure: %d", ret);
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
resp->total = memInfo.total;
|
||||
resp->free = memInfo.free;
|
||||
return;
|
||||
|
||||
for (i = 0; i < devices; i++) {
|
||||
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
// Report the lowest major.minor we detect as that limits our compatibility
|
||||
if (resp->major == 0 || resp->major > major ) {
|
||||
resp->major = major;
|
||||
resp->minor = minor;
|
||||
} else if ( resp->major == major && resp->minor > minor ) {
|
||||
resp->minor = minor;
|
||||
}
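// Example: a system with an 8.6 card and a 6.1 card ends up reporting 6.1,
// since the oldest card constrains which CUDA features can be used.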
|
||||
}
|
||||
}
|
||||
#endif // __APPLE__
|
@@ -15,12 +15,26 @@ typedef struct nvmlMemory_st {
|
||||
unsigned long long used;
|
||||
} nvmlMemory_t;
|
||||
|
||||
typedef enum nvmlBrandType_enum
|
||||
{
|
||||
NVML_BRAND_UNKNOWN = 0,
|
||||
} nvmlBrandType_t;
|
||||
|
||||
typedef struct cuda_handle {
|
||||
void *handle;
|
||||
nvmlReturn_t (*initFn)(void);
|
||||
nvmlReturn_t (*shutdownFn)(void);
|
||||
nvmlReturn_t (*getHandle)(unsigned int, nvmlDevice_t *);
|
||||
nvmlReturn_t (*getMemInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
uint16_t verbose;
|
||||
nvmlReturn_t (*nvmlInit_v2)(void);
|
||||
nvmlReturn_t (*nvmlShutdown)(void);
|
||||
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
|
||||
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
|
||||
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
|
||||
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
|
||||
} cuda_handle_t;
|
||||
|
||||
typedef struct cuda_init_resp {
|
||||
@@ -28,8 +42,15 @@ typedef struct cuda_init_resp {
|
||||
cuda_handle_t ch;
|
||||
} cuda_init_resp_t;
|
||||
|
||||
void cuda_init(cuda_init_resp_t *resp);
|
||||
typedef struct cuda_compute_capability {
|
||||
char *err;
|
||||
int major;
|
||||
int minor;
|
||||
} cuda_compute_capability_t;
|
||||
|
||||
void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp);
|
||||
void cuda_check_vram(cuda_handle_t ch, mem_info_t *resp);
|
||||
void cuda_compute_capability(cuda_handle_t ch, cuda_compute_capability_t *cc);
|
||||
|
||||
#endif // __GPU_INFO_CUDA_H__
|
||||
#endif // __APPLE__
|
@@ -4,22 +4,7 @@
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#ifndef _WIN32
|
||||
const char *rocm_lib_paths[] = {
|
||||
"librocm_smi64.so",
|
||||
"/opt/rocm/lib/librocm_smi64.so",
|
||||
NULL,
|
||||
};
|
||||
#else
|
||||
// TODO untested
|
||||
const char *rocm_lib_paths[] = {
|
||||
"rocm_smi64.dll",
|
||||
"/opt/rocm/lib/rocm_smi64.dll",
|
||||
NULL,
|
||||
};
|
||||
#endif
|
||||
|
||||
void rocm_init(rocm_init_resp_t *resp) {
|
||||
void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
|
||||
rsmi_status_t ret;
|
||||
resp->err = NULL;
|
||||
const int buflen = 256;
|
||||
@@ -28,32 +13,48 @@ void rocm_init(rocm_init_resp_t *resp) {
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[4] = {
|
||||
{"rsmi_init", (void *)&resp->rh.initFn},
|
||||
{"rsmi_shut_down", (void *)&resp->rh.shutdownFn},
|
||||
{"rsmi_dev_memory_total_get", (void *)&resp->rh.totalMemFn},
|
||||
{"rsmi_dev_memory_usage_get", (void *)&resp->rh.usageMemFn},
|
||||
// { "rsmi_dev_id_get", (void*)&resp->rh.getHandle },
|
||||
} l[] = {
|
||||
{"rsmi_init", (void *)&resp->rh.rsmi_init},
|
||||
{"rsmi_shut_down", (void *)&resp->rh.rsmi_shut_down},
|
||||
{"rsmi_dev_memory_total_get", (void *)&resp->rh.rsmi_dev_memory_total_get},
|
||||
{"rsmi_dev_memory_usage_get", (void *)&resp->rh.rsmi_dev_memory_usage_get},
|
||||
{"rsmi_version_get", (void *)&resp->rh.rsmi_version_get},
|
||||
{"rsmi_num_monitor_devices", (void*)&resp->rh.rsmi_num_monitor_devices},
|
||||
{"rsmi_dev_id_get", (void*)&resp->rh.rsmi_dev_id_get},
|
||||
{"rsmi_dev_name_get", (void *)&resp->rh.rsmi_dev_name_get},
|
||||
{"rsmi_dev_brand_get", (void *)&resp->rh.rsmi_dev_brand_get},
|
||||
{"rsmi_dev_vendor_name_get", (void *)&resp->rh.rsmi_dev_vendor_name_get},
|
||||
{"rsmi_dev_vram_vendor_get", (void *)&resp->rh.rsmi_dev_vram_vendor_get},
|
||||
{"rsmi_dev_serial_number_get", (void *)&resp->rh.rsmi_dev_serial_number_get},
|
||||
{"rsmi_dev_subsystem_name_get", (void *)&resp->rh.rsmi_dev_subsystem_name_get},
|
||||
{"rsmi_dev_vbios_version_get", (void *)&resp->rh.rsmi_dev_vbios_version_get},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
for (i = 0; rocm_lib_paths[i] != NULL && resp->rh.handle == NULL; i++) {
|
||||
resp->rh.handle = LOAD_LIBRARY(rocm_lib_paths[i], RTLD_LAZY);
|
||||
}
|
||||
resp->rh.handle = LOAD_LIBRARY(rocm_lib_path, RTLD_LAZY);
|
||||
if (!resp->rh.handle) {
|
||||
char *msg = LOAD_ERR();
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Radeon GPUs: %s\n",
|
||||
rocm_lib_paths[0], msg);
|
||||
rocm_lib_path, msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; i < 4; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->rh.verbose, "wiring rocm management library functions in %s\n", rocm_lib_path);
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->rh.verbose, "dlsym: %s\n", l[i].s);
|
||||
|
||||
*l[i].p = LOAD_SYMBOL(resp->rh.handle, l[i].s);
|
||||
if (!l[i].p) {
|
||||
UNLOAD_LIBRARY(resp->rh.handle);
|
||||
resp->rh.handle = NULL;
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->rh.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->rh.handle);
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
||||
msg);
|
||||
free(msg);
|
||||
@@ -62,8 +63,11 @@ void rocm_init(rocm_init_resp_t *resp) {
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*resp->rh.initFn)(0);
|
||||
ret = (*resp->rh.rsmi_init)(0);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(resp->rh.verbose, "rsmi_init err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->rh.handle);
|
||||
resp->rh.handle = NULL;
|
||||
snprintf(buf, buflen, "rocm vram init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
}
|
||||
@@ -73,8 +77,7 @@ void rocm_init(rocm_init_resp_t *resp) {
|
||||
|
||||
void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
|
||||
resp->err = NULL;
|
||||
// uint32_t num_devices;
|
||||
// uint16_t device;
|
||||
resp->igpu_index = -1;
|
||||
uint64_t totalMem = 0;
|
||||
uint64_t usedMem = 0;
|
||||
rsmi_status_t ret;
|
||||
@@ -83,36 +86,113 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
|
||||
int i;
|
||||
|
||||
if (h.handle == NULL) {
|
||||
resp->err = strdup("nvml handle sn't initialized");
|
||||
resp->err = strdup("rocm handle not initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO - iterate through devices... ret =
|
||||
// rsmi_num_monitor_devices(&num_devices);
|
||||
|
||||
// ret = (*h.getHandle)(0, &device);
|
||||
// if (ret != RSMI_STATUS_SUCCESS) {
|
||||
// printf("rocm vram device lookup failure: %d\n", ret);
|
||||
// return -1;
|
||||
// }
|
||||
|
||||
// Get total memory - used memory for available memory
|
||||
ret = (*h.totalMemFn)(0, RSMI_MEM_TYPE_VRAM, &totalMem);
|
||||
ret = (*h.rsmi_num_monitor_devices)(&resp->count);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
snprintf(buf, buflen, "rocm total mem lookup failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
ret = (*h.usageMemFn)(0, RSMI_MEM_TYPE_VRAM, &usedMem);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
snprintf(buf, buflen, "rocm usage mem lookup failure: %d", ret);
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
LOG(h.verbose, "discovered %d ROCm GPU Devices\n", resp->count);
|
||||
|
||||
resp->total = totalMem;
|
||||
resp->free = totalMem - usedMem;
|
||||
return;
|
||||
resp->total = 0;
|
||||
resp->free = 0;
|
||||
for (i = 0; i < resp->count; i++) {
|
||||
if (h.verbose) {
|
||||
// When in verbose mode, report more information about
|
||||
// the card we discover, but don't fail on error
|
||||
ret = (*h.rsmi_dev_name_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_name_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm device name: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_brand_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_brand_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm brand: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_vendor_name_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_vendor_name_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm vendor: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_vram_vendor_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_vram_vendor_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm VRAM vendor: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_serial_number_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_serial_number_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm S/N: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_subsystem_name_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_subsystem_name_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm subsystem name: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_vbios_version_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_vbios_version_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm vbios version: %s\n", i, buf);
|
||||
}
|
||||
}
|
||||
|
||||
// Get total memory - used memory for available memory
|
||||
ret = (*h.rsmi_dev_memory_total_get)(i, RSMI_MEM_TYPE_VRAM, &totalMem);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
snprintf(buf, buflen, "rocm total mem lookup failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
ret = (*h.rsmi_dev_memory_usage_get)(i, RSMI_MEM_TYPE_VRAM, &usedMem);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
snprintf(buf, buflen, "rocm usage mem lookup failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem);
|
||||
LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem);
|
||||
if (totalMem < 1024 * 1024 * 1024) {
|
||||
// Do not add up integrated GPU memory capacity; it's a bogus 512M and actually uses system memory
|
||||
LOG(h.verbose, "[%d] ROCm integrated GPU\n", i);
|
||||
resp->igpu_index = i;
|
||||
} else {
|
||||
resp->total += totalMem;
|
||||
resp->free += totalMem - usedMem;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
if (h.handle == NULL) {
|
||||
resp->str = strdup("rocm handle not initialized");
|
||||
resp->status = 1;
|
||||
return;
|
||||
}
|
||||
rsmi_version_t ver;
|
||||
rsmi_status_t ret;
|
||||
ret = h.rsmi_version_get(&ver);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
snprintf(buf, buflen, "unexpected response on version lookup %d", ret);
|
||||
resp->status = 1;
|
||||
} else {
|
||||
snprintf(buf, buflen, "%d", ver.major);
|
||||
resp->status = 0;
|
||||
}
|
||||
resp->str = strdup(buf);
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
|
@@ -15,13 +15,30 @@ typedef enum rsmi_memory_type {
|
||||
RSMI_MEM_TYPE_GTT,
|
||||
} rsmi_memory_type_t;
|
||||
|
||||
typedef struct {
|
||||
uint32_t major;
|
||||
uint32_t minor;
|
||||
uint32_t patch;
|
||||
const char *build;
|
||||
} rsmi_version_t;
|
||||
|
||||
typedef struct rocm_handle {
|
||||
void *handle;
|
||||
rsmi_status_t (*initFn)(uint64_t);
|
||||
rsmi_status_t (*shutdownFn)(void);
|
||||
rsmi_status_t (*totalMemFn)(uint32_t, rsmi_memory_type_t, uint64_t *);
|
||||
rsmi_status_t (*usageMemFn)(uint32_t, rsmi_memory_type_t, uint64_t *);
|
||||
// rsmi_status_t (*getHandle)(uint32_t, uint16_t *);
|
||||
uint16_t verbose;
|
||||
rsmi_status_t (*rsmi_init)(uint64_t);
|
||||
rsmi_status_t (*rsmi_shut_down)(void);
|
||||
rsmi_status_t (*rsmi_dev_memory_total_get)(uint32_t, rsmi_memory_type_t, uint64_t *);
|
||||
rsmi_status_t (*rsmi_dev_memory_usage_get)(uint32_t, rsmi_memory_type_t, uint64_t *);
|
||||
rsmi_status_t (*rsmi_version_get) (rsmi_version_t *version);
|
||||
rsmi_status_t (*rsmi_num_monitor_devices) (uint32_t *);
|
||||
rsmi_status_t (*rsmi_dev_id_get)(uint32_t, uint16_t *);
|
||||
rsmi_status_t (*rsmi_dev_name_get) (uint32_t,char *,size_t);
|
||||
rsmi_status_t (*rsmi_dev_brand_get) (uint32_t, char *, uint32_t);
|
||||
rsmi_status_t (*rsmi_dev_vendor_name_get) (uint32_t, char *, uint32_t);
|
||||
rsmi_status_t (*rsmi_dev_vram_vendor_get) (uint32_t, char *, uint32_t);
|
||||
rsmi_status_t (*rsmi_dev_serial_number_get) (uint32_t, char *, uint32_t);
|
||||
rsmi_status_t (*rsmi_dev_subsystem_name_get) (uint32_t, char *, uint32_t);
|
||||
rsmi_status_t (*rsmi_dev_vbios_version_get) (uint32_t, char *, uint32_t);
|
||||
} rocm_handle_t;
|
||||
|
||||
typedef struct rocm_init_resp {
|
||||
@@ -29,8 +46,14 @@ typedef struct rocm_init_resp {
|
||||
rocm_handle_t rh;
|
||||
} rocm_init_resp_t;
|
||||
|
||||
void rocm_init(rocm_init_resp_t *resp);
|
||||
typedef struct rocm_version_resp {
|
||||
rsmi_status_t status;
|
||||
char *str; // Contains version or error string if status != 0
|
||||
} rocm_version_resp_t;
|
||||
|
||||
void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp);
|
||||
void rocm_check_vram(rocm_handle_t rh, mem_info_t *resp);
|
||||
void rocm_get_version(rocm_handle_t rh, rocm_version_resp_t *resp);
|
||||
|
||||
#endif // __GPU_INFO_ROCM_H__
|
||||
#endif // __APPLE__
|
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
func TestBasicGetGPUInfo(t *testing.T) {
|
||||
info := GetGPUInfo()
|
||||
assert.Contains(t, "cuda rocm cpu default", info.Library)
|
||||
assert.Contains(t, "cuda rocm cpu metal", info.Library)
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
@@ -18,6 +18,7 @@ func TestBasicGetGPUInfo(t *testing.T) {
|
||||
case "linux", "windows":
|
||||
assert.Greater(t, info.TotalMemory, uint64(0))
|
||||
assert.Greater(t, info.FreeMemory, uint64(0))
|
||||
assert.Greater(t, info.DeviceCount, uint32(0))
|
||||
default:
|
||||
return
|
||||
}
|
||||
@@ -35,7 +36,6 @@ func TestCPUMemInfo(t *testing.T) {
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TODO - add some logic to figure out card type through other means and actually verify we got back what we expected
|
||||
|
@@ -3,6 +3,7 @@ package gpu
|
||||
type memInfo struct {
|
||||
TotalMemory uint64 `json:"total_memory,omitempty"`
|
||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||
DeviceCount uint32 `json:"device_count,omitempty"`
|
||||
}
|
||||
|
||||
// Beginning of an `ollama info` command
|
||||
@@ -10,5 +11,8 @@ type GpuInfo struct {
|
||||
memInfo
|
||||
Library string `json:"library,omitempty"`
|
||||
|
||||
// Optional variant to select (e.g. versions, cpu feature flags)
|
||||
Variant string `json:"variant,omitempty"`
|
||||
|
||||
// TODO add other useful attributes about the card here for discovery information
|
||||
}
|
||||
|
@@ -1,11 +1,11 @@
|
||||
#include "dynamic_shim.h"
|
||||
#include "dyn_ext_server.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef __linux__
|
||||
#include <dlfcn.h>
|
||||
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags | RTLD_DEEPBIND)
|
||||
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
|
||||
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
|
||||
#define LOAD_ERR() strdup(dlerror())
|
||||
#define UNLOAD_LIBRARY(handle) dlclose(handle)
|
||||
@@ -33,7 +33,7 @@ inline char *LOAD_ERR() {
|
||||
#define UNLOAD_LIBRARY(handle) dlclose(handle)
|
||||
#endif
|
||||
|
||||
void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
|
||||
void dyn_init(const char *libPath, struct dynamic_llama_server *s,
|
||||
ext_server_resp_t *err) {
|
||||
int i = 0;
|
||||
struct lookup {
|
||||
@@ -58,8 +58,8 @@ void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
|
||||
{"", NULL},
|
||||
};
|
||||
|
||||
printf("Lazy loading %s library\n", libPath);
|
||||
s->handle = LOAD_LIBRARY(libPath, RTLD_NOW);
|
||||
printf("loading library %s\n", libPath);
|
||||
s->handle = LOAD_LIBRARY(libPath, RTLD_LOCAL|RTLD_NOW);
|
||||
if (!s->handle) {
|
||||
err->id = -1;
|
||||
char *msg = LOAD_ERR();
|
||||
@@ -83,63 +83,63 @@ void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
|
||||
}
|
||||
}
|
||||
|
||||
inline void dynamic_shim_llama_server_init(struct dynamic_llama_server s,
|
||||
inline void dyn_llama_server_init(struct dynamic_llama_server s,
|
||||
ext_server_params_t *sparams,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_init(sparams, err);
|
||||
}
|
||||
|
||||
inline void dynamic_shim_llama_server_start(struct dynamic_llama_server s) {
|
||||
inline void dyn_llama_server_start(struct dynamic_llama_server s) {
|
||||
s.llama_server_start();
|
||||
}
|
||||
|
||||
inline void dynamic_shim_llama_server_stop(struct dynamic_llama_server s) {
|
||||
inline void dyn_llama_server_stop(struct dynamic_llama_server s) {
|
||||
s.llama_server_stop();
|
||||
}
|
||||
|
||||
inline void dynamic_shim_llama_server_completion(struct dynamic_llama_server s,
|
||||
inline void dyn_llama_server_completion(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
ext_server_resp_t *resp) {
|
||||
s.llama_server_completion(json_req, resp);
|
||||
}
|
||||
|
||||
inline void dynamic_shim_llama_server_completion_next_result(
|
||||
inline void dyn_llama_server_completion_next_result(
|
||||
struct dynamic_llama_server s, const int task_id,
|
||||
ext_server_task_result_t *result) {
|
||||
s.llama_server_completion_next_result(task_id, result);
|
||||
}
|
||||
|
||||
inline void dynamic_shim_llama_server_completion_cancel(
|
||||
inline void dyn_llama_server_completion_cancel(
|
||||
struct dynamic_llama_server s, const int task_id, ext_server_resp_t *err) {
|
||||
s.llama_server_completion_cancel(task_id, err);
|
||||
}
|
||||
inline void dynamic_shim_llama_server_release_task_result(
|
||||
inline void dyn_llama_server_release_task_result(
|
||||
struct dynamic_llama_server s, ext_server_task_result_t *result) {
|
||||
s.llama_server_release_task_result(result);
|
||||
}
|
||||
|
||||
inline void dynamic_shim_llama_server_tokenize(struct dynamic_llama_server s,
|
||||
inline void dyn_llama_server_tokenize(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_tokenize(json_req, json_resp, err);
|
||||
}
|
||||
|
||||
inline void dynamic_shim_llama_server_detokenize(struct dynamic_llama_server s,
|
||||
inline void dyn_llama_server_detokenize(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_detokenize(json_req, json_resp, err);
|
||||
}
|
||||
|
||||
inline void dynamic_shim_llama_server_embedding(struct dynamic_llama_server s,
|
||||
inline void dyn_llama_server_embedding(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_embedding(json_req, json_resp, err);
|
||||
}
|
||||
|
||||
inline void dynamic_shim_llama_server_release_json_resp(
|
||||
inline void dyn_llama_server_release_json_resp(
|
||||
struct dynamic_llama_server s, char **json_resp) {
|
||||
s.llama_server_release_json_resp(json_resp);
|
||||
}
|
@@ -1,63 +1,45 @@
|
||||
package llm
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/gguf -I${SRCDIR}/llama.cpp/gguf/common -I${SRCDIR}/llama.cpp/gguf/examples/server
|
||||
#cgo CFLAGS: -I${SRCDIR}/ext_server -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/common -I${SRCDIR}/llama.cpp/examples/server
|
||||
#cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
|
||||
#cgo CFLAGS: -Wmissing-noreturn -Wall -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
|
||||
#cgo CPPFLAGS: -Ofast -Wall -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations -Wno-unused-but-set-variable
|
||||
#cgo CFLAGS: -Wmissing-noreturn -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
|
||||
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations
|
||||
#cgo darwin CFLAGS: -D_DARWIN_C_SOURCE
|
||||
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE
|
||||
#cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
|
||||
#cgo darwin LDFLAGS: -lc++ -framework Accelerate
|
||||
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
|
||||
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/darwin/metal/lib/libcommon.a
|
||||
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/darwin/metal/lib/libext_server.a
|
||||
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/darwin/metal/lib/libllama.a
|
||||
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/darwin/metal/lib/libggml_static.a
|
||||
#cgo linux CFLAGS: -D_GNU_SOURCE
|
||||
#cgo linux windows CFLAGS: -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS
|
||||
#cgo linux LDFLAGS: -L/usr/local/cuda/targets/x86_64-linux/lib -L/usr/local/cuda/lib64 -L/usr/local/cuda/targets/x86_64-linux/lib/stubs
|
||||
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/linux/cpu/lib/libext_server.a
|
||||
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/linux/cpu/lib/libcommon.a
|
||||
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/linux/cpu/lib/libllama.a
|
||||
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/linux/cpu/lib/libggml_static.a
|
||||
#cgo linux LDFLAGS: -lrt -ldl -lstdc++ -lm
|
||||
#cgo linux windows LDFLAGS: -lpthread
|
||||
|
||||
#include <stdlib.h>
|
||||
#include "ext_server.h"
|
||||
#include "dyn_ext_server.h"
|
||||
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/gpu"
|
||||
)
|
||||
|
||||
type extServer interface {
|
||||
LLM
|
||||
llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t)
|
||||
llama_server_start()
|
||||
llama_server_stop()
|
||||
llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t)
|
||||
llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t)
|
||||
llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t)
|
||||
llama_server_release_task_result(result *C.ext_server_task_result_t)
|
||||
llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
|
||||
llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
|
||||
llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
|
||||
llama_server_release_json_resp(json_resp **C.char)
|
||||
type dynExtServer struct {
|
||||
s C.struct_dynamic_llama_server
|
||||
options api.Options
|
||||
}
|
||||
|
||||
// Note: current implementation does not support concurrent instantiations
|
||||
@@ -82,25 +64,39 @@ func extServerResponseToErr(resp C.ext_server_resp_t) error {
|
||||
return fmt.Errorf(C.GoString(resp.msg))
|
||||
}
|
||||
|
||||
func newExtServer(server extServer, model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
|
||||
// Note: current implementation does not support concurrent instantiations
|
||||
var llm *dynExtServer
|
||||
|
||||
func newDynExtServer(library, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
|
||||
if !mutex.TryLock() {
|
||||
log.Printf("concurrent llm servers not yet supported, waiting for prior server to complete")
|
||||
slog.Info("concurrent llm servers not yet supported, waiting for prior server to complete")
|
||||
mutex.Lock()
|
||||
}
|
||||
fileInfo, err := os.Stat(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
updatePath(filepath.Dir(library))
|
||||
libPath := C.CString(library)
|
||||
defer C.free(unsafe.Pointer(libPath))
|
||||
resp := newExtServerResp(512)
|
||||
defer freeExtServerResp(resp)
|
||||
var srv C.struct_dynamic_llama_server
|
||||
C.dyn_init(libPath, &srv, &resp)
|
||||
if resp.id < 0 {
|
||||
mutex.Unlock()
|
||||
return nil, fmt.Errorf("Unable to load dynamic library: %s", C.GoString(resp.msg))
|
||||
}
|
||||
llm = &dynExtServer{
|
||||
s: srv,
|
||||
options: opts,
|
||||
}
|
||||
slog.Info(fmt.Sprintf("Loading Dynamic llm server: %s", library))
|
||||
|
||||
var sparams C.ext_server_params_t
|
||||
sparams.model = C.CString(model)
|
||||
defer C.free(unsafe.Pointer(sparams.model))
|
||||
|
||||
numGPU := gpu.NumGPU(numLayers, fileInfo.Size(), opts)
|
||||
|
||||
sparams.embedding = true
|
||||
sparams.n_ctx = C.uint(opts.NumCtx)
|
||||
sparams.n_batch = C.uint(opts.NumBatch)
|
||||
sparams.n_gpu_layers = C.int(numGPU)
|
||||
sparams.n_gpu_layers = C.int(opts.NumGPU)
|
||||
sparams.main_gpu = C.int(opts.MainGPU)
|
||||
sparams.n_parallel = 1 // TODO - wire up concurrency
|
||||
|
||||
@@ -140,29 +136,35 @@ func newExtServer(server extServer, model string, adapters, projectors []string,
|
||||
|
||||
sparams.n_threads = C.uint(opts.NumThread)
|
||||
|
||||
log.Printf("Initializing internal llama server")
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
server.llama_server_init(&sparams, &resp)
|
||||
if resp.id < 0 {
|
||||
return nil, extServerResponseToErr(resp)
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
sparams.verbose_logging = C.bool(true)
|
||||
} else {
|
||||
sparams.verbose_logging = C.bool(false)
|
||||
}
|
||||
|
||||
log.Printf("Starting internal llama main loop")
|
||||
server.llama_server_start()
|
||||
return server, nil
|
||||
slog.Info("Initializing llama server")
|
||||
initResp := newExtServerResp(128)
|
||||
defer freeExtServerResp(initResp)
|
||||
C.dyn_llama_server_init(llm.s, &sparams, &initResp)
|
||||
if initResp.id < 0 {
|
||||
mutex.Unlock()
|
||||
err := extServerResponseToErr(initResp)
|
||||
slog.Debug(fmt.Sprintf("failure during initialization: %s", err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Info("Starting llama main loop")
|
||||
C.dyn_llama_server_start(llm.s)
|
||||
return llm, nil
|
||||
}
|
||||
|
||||
func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(PredictResult)) error {
|
||||
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
var imageData []ImageData
|
||||
|
||||
if len(predict.Images) > 0 {
|
||||
for cnt, i := range predict.Images {
|
||||
imageData = append(imageData, ImageData{Data: i, ID: cnt})
|
||||
}
|
||||
slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
|
||||
}
|
||||
log.Printf("loaded %d images", len(imageData))
|
||||
|
||||
request := map[string]any{
|
||||
"prompt": predict.Prompt,
|
||||
@@ -184,7 +186,7 @@ func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(Pr
|
||||
"penalize_nl": predict.Options.PenalizeNewline,
|
||||
"seed": predict.Options.Seed,
|
||||
"stop": predict.Options.Stop,
|
||||
"image_data": imageData,
|
||||
"image_data": predict.Images,
|
||||
"cache_prompt": true,
|
||||
}
|
||||
|
||||
@@ -211,7 +213,7 @@ func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(Pr
|
||||
req := C.CString(buffer.String())
|
||||
defer C.free(unsafe.Pointer(req))
|
||||
|
||||
llm.llama_server_completion(req, &resp)
|
||||
C.dyn_llama_server_completion(llm.s, req, &resp)
|
||||
if resp.id < 0 {
|
||||
return extServerResponseToErr(resp)
|
||||
}
|
||||
@@ -222,7 +224,7 @@ func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(Pr
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// This handles the request cancellation
|
||||
llm.llama_server_completion_cancel(resp.id, &resp)
|
||||
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
|
||||
if resp.id < 0 {
|
||||
return extServerResponseToErr(resp)
|
||||
} else {
|
||||
@@ -230,13 +232,13 @@ func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(Pr
|
||||
}
|
||||
default:
|
||||
var result C.ext_server_task_result_t
|
||||
llm.llama_server_completion_next_result(resp.id, &result)
|
||||
C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
|
||||
json_resp := C.GoString(result.json_resp)
|
||||
llm.llama_server_release_task_result(&result)
|
||||
C.dyn_llama_server_release_task_result(llm.s, &result)
|
||||
|
||||
var p prediction
|
||||
if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
|
||||
llm.llama_server_completion_cancel(resp.id, &resp)
|
||||
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
|
||||
if resp.id < 0 {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
|
||||
} else {
|
||||
@@ -277,7 +279,7 @@ func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(Pr
|
||||
return fmt.Errorf("max retries exceeded")
|
||||
}
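A minimal sketch of driving the `Predict` method above from within the same package: cancelling the context takes the `dyn_llama_server_completion_cancel` path shown earlier. `PredictOpts.Prompt` appears in this diff; the `Content` field on `PredictResult` is an assumption for illustration.

```go
// Sketch: stream a completion and abandon it after 30 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

err := llm.Predict(ctx, PredictOpts{Prompt: "Why is the sky blue?"}, func(r PredictResult) {
	fmt.Print(r.Content) // assumed field; called once per streamed chunk
})
if err != nil {
	slog.Warn(fmt.Sprintf("predict ended early: %s", err))
}
```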
|
||||
|
||||
func encode(llm extServer, ctx context.Context, prompt string) ([]int, error) {
|
||||
func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling encode data: %w", err)
|
||||
@@ -287,11 +289,11 @@ func encode(llm extServer, ctx context.Context, prompt string) ([]int, error) {
|
||||
var json_resp *C.char
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
llm.llama_server_tokenize(req, &json_resp, &resp)
|
||||
C.dyn_llama_server_tokenize(llm.s, req, &json_resp, &resp)
|
||||
if resp.id < 0 {
|
||||
return nil, extServerResponseToErr(resp)
|
||||
}
|
||||
defer llm.llama_server_release_json_resp(&json_resp)
|
||||
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
|
||||
|
||||
var encoded TokenizeResponse
|
||||
if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &encoded); err2 != nil {
|
||||
@@ -301,7 +303,7 @@ func encode(llm extServer, ctx context.Context, prompt string) ([]int, error) {
|
||||
return encoded.Tokens, err
|
||||
}
|
||||
|
||||
func decode(llm extServer, ctx context.Context, tokens []int) (string, error) {
|
||||
func (llm *dynExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
|
||||
if len(tokens) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
@@ -315,11 +317,11 @@ func decode(llm extServer, ctx context.Context, tokens []int) (string, error) {
|
||||
var json_resp *C.char
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
llm.llama_server_detokenize(req, &json_resp, &resp)
|
||||
C.dyn_llama_server_detokenize(llm.s, req, &json_resp, &resp)
|
||||
if resp.id < 0 {
|
||||
return "", extServerResponseToErr(resp)
|
||||
}
|
||||
defer llm.llama_server_release_json_resp(&json_resp)
|
||||
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
|
||||
|
||||
var decoded DetokenizeResponse
|
||||
if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &decoded); err2 != nil {
|
||||
@@ -329,7 +331,7 @@ func decode(llm extServer, ctx context.Context, tokens []int) (string, error) {
|
||||
return decoded.Content, err
|
||||
}
|
||||
|
||||
func embedding(llm extServer, ctx context.Context, input string) ([]float64, error) {
|
||||
func (llm *dynExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
|
||||
data, err := json.Marshal(TokenizeRequest{Content: input})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
@@ -340,11 +342,11 @@ func embedding(llm extServer, ctx context.Context, input string) ([]float64, err
|
||||
var json_resp *C.char
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
llm.llama_server_embedding(req, &json_resp, &resp)
|
||||
C.dyn_llama_server_embedding(llm.s, req, &json_resp, &resp)
|
||||
if resp.id < 0 {
|
||||
return nil, extServerResponseToErr(resp)
|
||||
}
|
||||
defer llm.llama_server_release_json_resp(&json_resp)
|
||||
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
|
||||
|
||||
var embedding EmbeddingResponse
|
||||
if err := json.Unmarshal([]byte(C.GoString(json_resp)), &embedding); err != nil {
|
||||
@@ -354,7 +356,29 @@ func embedding(llm extServer, ctx context.Context, input string) ([]float64, err
|
||||
return embedding.Embedding, nil
|
||||
}
|
||||
|
||||
func close(llm extServer) {
|
||||
llm.llama_server_stop()
|
||||
func (llm *dynExtServer) Close() {
|
||||
C.dyn_llama_server_stop(llm.s)
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
func updatePath(dir string) {
|
||||
if runtime.GOOS == "windows" {
|
||||
tmpDir := filepath.Dir(dir)
|
||||
pathComponents := strings.Split(os.Getenv("PATH"), ";")
|
||||
i := 0
|
||||
for _, comp := range pathComponents {
|
||||
if strings.EqualFold(comp, dir) {
|
||||
return
|
||||
}
|
||||
// Remove any other prior paths to our temp dir
|
||||
if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) {
|
||||
pathComponents[i] = comp
|
||||
i++
|
||||
}
|
||||
}
|
||||
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
|
||||
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
|
||||
os.Setenv("PATH", newPath)
|
||||
}
|
||||
// linux and darwin rely on rpath
|
||||
}
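For clarity, the Windows-only PATH rewrite above keeps unrelated entries, drops stale entries under the same temp parent directory, and prepends the new payload directory (it also returns early if the directory is already on the PATH). A standalone sketch of the same filtering idea, with a hypothetical `prependPath` helper:

```go
package main

import (
	"fmt"
	"strings"
)

// prependPath mirrors the filtering idea in updatePath: drop prior PATH
// entries under the same temp parent directory, then put the new dir first.
func prependPath(path, dir, tmpParent string) string {
	kept := []string{dir}
	for _, comp := range strings.Split(path, ";") {
		if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpParent)) {
			kept = append(kept, comp)
		}
	}
	return strings.Join(kept, ";")
}

func main() {
	old := `C:\Windows;C:\Temp\ollama456\cpu;C:\Go\bin`
	fmt.Println(prependPath(old, `C:\Temp\ollama456\cuda`, `C:\Temp\ollama456`))
	// Output: C:\Temp\ollama456\cuda;C:\Windows;C:\Go\bin
}
```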
|
@@ -27,46 +27,46 @@ struct dynamic_llama_server {
|
||||
void (*llama_server_release_json_resp)(char **json_resp);
|
||||
};
|
||||
|
||||
void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
|
||||
void dyn_init(const char *libPath, struct dynamic_llama_server *s,
|
||||
ext_server_resp_t *err);
|
||||
|
||||
// No good way to call C function pointers from Go so inline the indirection
|
||||
void dynamic_shim_llama_server_init(struct dynamic_llama_server s,
|
||||
void dyn_llama_server_init(struct dynamic_llama_server s,
|
||||
ext_server_params_t *sparams,
|
||||
ext_server_resp_t *err);
|
||||
|
||||
void dynamic_shim_llama_server_start(struct dynamic_llama_server s);
|
||||
void dyn_llama_server_start(struct dynamic_llama_server s);
|
||||
|
||||
void dynamic_shim_llama_server_stop(struct dynamic_llama_server s);
|
||||
void dyn_llama_server_stop(struct dynamic_llama_server s);
|
||||
|
||||
void dynamic_shim_llama_server_completion(struct dynamic_llama_server s,
|
||||
void dyn_llama_server_completion(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
ext_server_resp_t *resp);
|
||||
|
||||
void dynamic_shim_llama_server_completion_next_result(
|
||||
void dyn_llama_server_completion_next_result(
|
||||
struct dynamic_llama_server s, const int task_id,
|
||||
ext_server_task_result_t *result);
|
||||
|
||||
void dynamic_shim_llama_server_completion_cancel(struct dynamic_llama_server s,
|
||||
void dyn_llama_server_completion_cancel(struct dynamic_llama_server s,
|
||||
const int task_id,
|
||||
ext_server_resp_t *err);
|
||||
|
||||
void dynamic_shim_llama_server_release_task_result(
|
||||
void dyn_llama_server_release_task_result(
|
||||
struct dynamic_llama_server s, ext_server_task_result_t *result);
|
||||
|
||||
void dynamic_shim_llama_server_tokenize(struct dynamic_llama_server s,
|
||||
void dyn_llama_server_tokenize(struct dynamic_llama_server s,
|
||||
const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
|
||||
void dynamic_shim_llama_server_detokenize(struct dynamic_llama_server s,
|
||||
void dyn_llama_server_detokenize(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
|
||||
void dynamic_shim_llama_server_embedding(struct dynamic_llama_server s,
|
||||
void dyn_llama_server_embedding(struct dynamic_llama_server s,
|
||||
const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
void dynamic_shim_llama_server_release_json_resp(struct dynamic_llama_server s,
|
||||
void dyn_llama_server_release_json_resp(struct dynamic_llama_server s,
|
||||
char **json_resp);
|
||||
|
||||
#ifdef __cplusplus
|
25 llm/ext_server/CMakeLists.txt Normal file
@@ -0,0 +1,25 @@
|
||||
# Ollama specific CMakefile to include in llama.cpp/examples/server
|
||||
|
||||
set(TARGET ext_server)
|
||||
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
|
||||
if (WIN32)
|
||||
add_library(${TARGET} SHARED ../../../ext_server/ext_server.cpp ../../llama.cpp)
|
||||
else()
|
||||
add_library(${TARGET} STATIC ../../../ext_server/ext_server.cpp ../../llama.cpp)
|
||||
endif()
|
||||
target_include_directories(${TARGET} PRIVATE ../../common)
|
||||
target_include_directories(${TARGET} PRIVATE ../..)
|
||||
target_include_directories(${TARGET} PRIVATE ../../..)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_SERVER_LIBRARY=1)
|
||||
target_link_libraries(${TARGET} PRIVATE ggml llava common )
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>)
|
||||
install(TARGETS ext_server LIBRARY)
|
||||
|
||||
if (CUDAToolkit_FOUND)
|
||||
target_include_directories(${TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
if (WIN32)
|
||||
target_link_libraries(${TARGET} PRIVATE nvml)
|
||||
endif()
|
||||
endif()
|
18 llm/ext_server/README.md Normal file
@@ -0,0 +1,18 @@
# Extern C Server

This directory contains a thin facade layered on top of the llama.cpp server to
expose `extern C` interfaces, so the functionality can be reached through direct
in-process API calls. The llama.cpp code uses compile-time macros to configure the
GPU type along with other settings. During the `go generate ./...` execution, the
build produces one or more copies of the llama.cpp `extern C` server based on which
GPU libraries are detected, covering multiple GPU types as well as CPU-only support.
The Ollama Go build then embeds these different servers so the right one can be
selected for the available GPU and settings at runtime.

If you are making changes to the code in this directory, make sure to disable
caching during your Go build so that you pick up your changes. A typical iteration
cycle from the top of the source tree looks like:

```
go generate ./... && go build -a .
```
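To sketch how these embedded server variants might be consumed at runtime: the Go side keeps a map of the generated libraries (an `AvailableShims` map keyed by variant name appears later in this diff) and falls back to the plain CPU build when nothing better matches the host. The variant names and selection order below are assumptions for illustration, not Ollama's exact logic.

```go
// pickVariant chooses a generated server library by capability, falling back
// to the lowest-common-denominator CPU build. The keys mirror the BUILD_DIR
// names produced by the generate scripts (cpu, cpu_avx, cpu_avx2, cuda_v11, ...).
func pickVariant(available map[string]string, haveCUDA, haveAVX2 bool) string {
	if haveCUDA {
		if lib, ok := available["cuda_v11"]; ok { // hypothetical CUDA 11 build
			return lib
		}
	}
	if haveAVX2 {
		if lib, ok := available["cpu_avx2"]; ok {
			return lib
		}
	}
	return available["cpu"]
}
```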
@@ -3,21 +3,44 @@
|
||||
// Necessary evil since the server types are not defined in a header
|
||||
#include "server.cpp"
|
||||
|
||||
// Low level API access to verify GPU access
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// for rocblas_initialize()
|
||||
#include "rocblas/rocblas.h"
|
||||
#endif // __HIP_PLATFORM_AMD__
|
||||
#define cudaGetDevice hipGetDevice
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaSuccess hipSuccess
|
||||
#define cudaGetErrorString hipGetErrorString
|
||||
#else
|
||||
#include <cuda_runtime.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_fp16.h>
|
||||
#endif // defined(GGML_USE_HIPBLAS)
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
// Expose the llama server as a callable extern "C" API
|
||||
llama_server_context *llama = NULL;
|
||||
std::atomic<bool> ext_server_running(false);
|
||||
std::thread ext_server_thread;
|
||||
|
||||
void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
|
||||
#if SERVER_VERBOSE != 1
|
||||
log_disable();
|
||||
#endif
|
||||
assert(err != NULL && sparams != NULL);
|
||||
log_set_target(stderr);
|
||||
if (!sparams->verbose_logging) {
|
||||
server_verbose = true;
|
||||
log_disable();
|
||||
}
|
||||
|
||||
LOG_TEE("system info: %s\n", llama_print_system_info());
|
||||
err->id = 0;
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
llama = new llama_server_context;
|
||||
log_set_target(stdout);
|
||||
gpt_params params;
|
||||
params.n_ctx = sparams->n_ctx;
|
||||
params.n_batch = sparams->n_batch;
|
||||
@@ -46,15 +69,31 @@ void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
|
||||
params.model = sparams->model;
|
||||
}
|
||||
|
||||
for (ext_server_lora_adapter *la = sparams->lora_adapters; la != NULL;
|
||||
la = la->next) {
|
||||
params.lora_adapter.push_back(std::make_tuple(la->adapter, la->scale));
|
||||
if (sparams->lora_adapters != NULL) {
|
||||
for (ext_server_lora_adapter *la = sparams->lora_adapters; la != NULL;
|
||||
la = la->next) {
|
||||
params.lora_adapter.push_back(std::make_tuple(la->adapter, la->scale));
|
||||
}
|
||||
|
||||
params.use_mmap = false;
|
||||
}
|
||||
|
||||
if (sparams->mmproj != NULL) {
|
||||
params.mmproj = std::string(sparams->mmproj);
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
// Before attempting to init the backend which will assert on error, verify the CUDA/ROCM GPU is accessible
|
||||
LOG_TEE("Performing pre-initialization of GPU\n");
|
||||
int id;
|
||||
cudaError_t cudaErr = cudaGetDevice(&id);
|
||||
if (cudaErr != cudaSuccess) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "Unable to init GPU: %s", cudaGetErrorString(cudaErr));
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
|
||||
// load the model
|
||||
@@ -83,18 +122,23 @@ void llama_server_start() {
|
||||
assert(llama != NULL);
|
||||
// TODO mutex to protect thread creation
|
||||
ext_server_thread = std::thread([&]() {
|
||||
ext_server_running = true;
|
||||
try {
|
||||
LOG_TEE("llama server main loop starting\n");
|
||||
ggml_time_init();
|
||||
while (ext_server_running.load()) {
|
||||
if (!llama->update_slots()) {
|
||||
LOG_TEE(
|
||||
"unexpected error in llama server update_slots - exiting main "
|
||||
"loop\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
llama->queue_tasks.on_new_task(std::bind(
|
||||
&llama_server_context::process_single_task, llama, std::placeholders::_1));
|
||||
llama->queue_tasks.on_finish_multitask(std::bind(
|
||||
&llama_server_context::on_finish_multitask, llama, std::placeholders::_1));
|
||||
llama->queue_tasks.on_all_tasks_finished(std::bind(
|
||||
&llama_server_context::run_on_all_tasks_finished, llama));
|
||||
llama->queue_results.on_multitask_update(std::bind(
|
||||
&llama_server_queue::update_multitask,
|
||||
&llama->queue_tasks,
|
||||
std::placeholders::_1,
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
llama->queue_tasks.start_loop();
|
||||
} catch (std::exception &e) {
|
||||
LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
|
||||
} catch (...) {
|
||||
@@ -107,9 +151,10 @@ void llama_server_start() {
|
||||
|
||||
void llama_server_stop() {
|
||||
assert(llama != NULL);
|
||||
// TODO - too verbose, remove once things are solid
|
||||
LOG_TEE("requesting llama server shutdown\n");
|
||||
ext_server_running = false;
|
||||
LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
|
||||
// This may take a while for any pending tasks to drain
|
||||
// TODO - consider a timeout to cancel tasks if it's taking too long
|
||||
llama->queue_tasks.terminate();
|
||||
ext_server_thread.join();
|
||||
delete llama;
|
||||
llama = NULL;
|
||||
@@ -122,7 +167,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
|
||||
resp->msg[0] = '\0';
|
||||
try {
|
||||
json data = json::parse(json_req);
|
||||
resp->id = llama->request_completion(data, false, false, -1);
|
||||
resp->id = llama->queue_tasks.get_new_id();
|
||||
llama->queue_results.add_waiting_task_id(resp->id);
|
||||
llama->request_completion(resp->id, data, false, false, -1);
|
||||
} catch (std::exception &e) {
|
||||
snprintf(resp->msg, resp->msg_len, "exception %s", e.what());
|
||||
} catch (...) {
|
||||
@@ -140,16 +187,22 @@ void llama_server_completion_next_result(const int task_id,
|
||||
resp->json_resp = NULL;
|
||||
std::string result_json;
|
||||
try {
|
||||
task_result result = llama->next_result(task_id);
|
||||
task_result result = llama->queue_results.recv(task_id);
|
||||
result_json =
|
||||
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
resp->id = result.id;
|
||||
resp->stop = result.stop;
|
||||
resp->error = result.error;
|
||||
if (result.error) {
|
||||
LOG_TEE("next result cancel on error\n");
|
||||
llama->request_cancel(task_id);
|
||||
LOG_TEE("next result removing waiting tak ID: %d\n", task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} else if (result.stop) {
|
||||
LOG_TEE("next result cancel on stop\n");
|
||||
llama->request_cancel(task_id);
|
||||
LOG_TEE("next result removing waiting task ID: %d\n", task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
resp->error = true;
|
||||
@@ -180,6 +233,7 @@ void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err) {
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
llama->request_cancel(task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
@@ -264,13 +318,15 @@ void llama_server_embedding(const char *json_req, char **json_resp,
|
||||
} else {
|
||||
prompt = "";
|
||||
}
|
||||
const int task_id = llama->request_completion(
|
||||
{{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
|
||||
task_result result = llama->next_result(task_id);
|
||||
const int task_id = llama->queue_tasks.get_new_id();
|
||||
llama->queue_results.add_waiting_task_id(task_id);
|
||||
llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
|
||||
task_result result = llama->queue_results.recv(task_id);
|
||||
std::string result_json = result.result_json.dump();
|
||||
const std::string::size_type size = result_json.size() + 1;
|
||||
*json_resp = new char[size];
|
||||
snprintf(*json_resp, size, "%s", result_json.c_str());
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
@@ -45,6 +45,7 @@ typedef struct ext_server_params {
|
||||
bool embedding; // get only sentence embedding
|
||||
ext_server_lora_adapter_t *lora_adapters;
|
||||
char *mmproj;
|
||||
bool verbose_logging; // Enable verbose logging of the server
|
||||
} ext_server_params_t;
|
||||
|
||||
typedef struct ext_server_task_result {
|
@@ -1,80 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package llm
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "ext_server.h"
|
||||
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
type llamaExtServer struct {
|
||||
api.Options
|
||||
}
|
||||
|
||||
func (llm *llamaExtServer) llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t) {
|
||||
C.llama_server_init(sparams, err)
|
||||
}
|
||||
func (llm *llamaExtServer) llama_server_start() {
|
||||
C.llama_server_start()
|
||||
}
|
||||
func (llm *llamaExtServer) llama_server_stop() {
|
||||
C.llama_server_stop()
|
||||
}
|
||||
|
||||
func (llm *llamaExtServer) llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t) {
|
||||
C.llama_server_completion(json_req, resp)
|
||||
}
|
||||
func (llm *llamaExtServer) llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t) {
|
||||
C.llama_server_completion_next_result(task_id, resp)
|
||||
}
|
||||
func (llm *llamaExtServer) llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t) {
|
||||
C.llama_server_completion_cancel(task_id, err)
|
||||
}
|
||||
func (llm *llamaExtServer) llama_server_release_task_result(result *C.ext_server_task_result_t) {
|
||||
C.llama_server_release_task_result(result)
|
||||
}
|
||||
|
||||
func (llm *llamaExtServer) llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
|
||||
C.llama_server_tokenize(json_req, json_resp, err)
|
||||
}
|
||||
func (llm *llamaExtServer) llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
|
||||
C.llama_server_detokenize(json_req, json_resp, err)
|
||||
}
|
||||
func (llm *llamaExtServer) llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
|
||||
C.llama_server_embedding(json_req, json_resp, err)
|
||||
}
|
||||
func (llm *llamaExtServer) llama_server_release_json_resp(json_resp **C.char) {
|
||||
C.llama_server_release_json_resp(json_resp)
|
||||
}
|
||||
|
||||
func newDefaultExtServer(model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
|
||||
server := &llamaExtServer{opts}
|
||||
return newExtServer(server, model, adapters, projectors, numLayers, opts)
|
||||
}
|
||||
|
||||
func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
|
||||
return predict(ctx, llm, pred, fn)
|
||||
}
|
||||
|
||||
func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||
return encode(llm, ctx, prompt)
|
||||
}
|
||||
|
||||
func (llm *llamaExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
|
||||
return decode(llm, ctx, tokens)
|
||||
}
|
||||
|
||||
func (llm *llamaExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
|
||||
return embedding(llm, ctx, input)
|
||||
}
|
||||
|
||||
func (llm *llamaExtServer) Close() {
|
||||
close(llm)
|
||||
}
|
@@ -1,12 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
func newDefaultExtServer(model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
|
||||
// On windows we always load the llama.cpp libraries dynamically to avoid startup DLL dependencies
|
||||
// This ensures we can update the PATH at runtime to get everything loaded
|
||||
|
||||
return newDynamicShimExtServer(AvailableShims["cpu"], model, adapters, projectors, numLayers, opts)
|
||||
}
|
125 llm/generate/gen_common.sh Normal file
@@ -0,0 +1,125 @@
|
||||
# common logic across linux and darwin
|
||||
|
||||
init_vars() {
|
||||
case "${GOARCH}" in
|
||||
"amd64")
|
||||
ARCH="x86_64"
|
||||
;;
|
||||
"arm64")
|
||||
ARCH="arm64"
|
||||
;;
|
||||
*)
|
||||
ARCH=$(uname -m | sed -e "s/aarch64/arm64/g")
|
||||
esac
|
||||
|
||||
LLAMACPP_DIR=../llama.cpp
|
||||
CMAKE_DEFS=""
|
||||
CMAKE_TARGETS="--target ext_server"
|
||||
if echo "${CGO_CFLAGS}" | grep -- '-g' >/dev/null; then
|
||||
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_VERBOSE_MAKEFILE=on -DLLAMA_GPROF=on -DLLAMA_SERVER_VERBOSE=on ${CMAKE_DEFS}"
|
||||
else
|
||||
# TODO - add additional optimization flags...
|
||||
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off ${CMAKE_DEFS}"
|
||||
fi
|
||||
case $(uname -s) in
|
||||
"Darwin")
|
||||
LIB_EXT="dylib"
|
||||
WHOLE_ARCHIVE="-Wl,-force_load"
|
||||
NO_WHOLE_ARCHIVE=""
|
||||
GCC_ARCH="-arch ${ARCH}"
|
||||
;;
|
||||
"Linux")
|
||||
LIB_EXT="so"
|
||||
WHOLE_ARCHIVE="-Wl,--whole-archive"
|
||||
NO_WHOLE_ARCHIVE="-Wl,--no-whole-archive"
|
||||
|
||||
# Cross compiling not supported on linux - Use docker
|
||||
GCC_ARCH=""
|
||||
;;
|
||||
*)
|
||||
;;
|
||||
esac
|
||||
if [ -z "${CMAKE_CUDA_ARCHITECTURES}" ] ; then
|
||||
CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
|
||||
fi
|
||||
}
|
||||
|
||||
git_module_setup() {
|
||||
if [ -n "${OLLAMA_SKIP_PATCHING}" ]; then
|
||||
echo "Skipping submodule initialization"
|
||||
return
|
||||
fi
|
||||
# Make sure the tree is clean after the directory moves
|
||||
if [ -d "${LLAMACPP_DIR}/gguf" ]; then
|
||||
echo "Cleaning up old submodule"
|
||||
rm -rf ${LLAMACPP_DIR}
|
||||
fi
|
||||
git submodule init
|
||||
git submodule update --force ${LLAMACPP_DIR}
|
||||
|
||||
}
|
||||
|
||||
apply_patches() {
|
||||
# Wire up our CMakefile
|
||||
if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then
|
||||
echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt
|
||||
fi
|
||||
|
||||
if [ -n "$(ls -A ../patches/*.diff)" ]; then
|
||||
# apply temporary patches until fix is upstream
|
||||
for patch in ../patches/*.diff; do
|
||||
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
|
||||
(cd ${LLAMACPP_DIR}; git checkout ${file})
|
||||
done
|
||||
done
|
||||
for patch in ../patches/*.diff; do
|
||||
(cd ${LLAMACPP_DIR} && git apply ${patch})
|
||||
done
|
||||
fi
|
||||
|
||||
# Avoid duplicate main symbols when we link into the cgo binary
|
||||
sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp &&
|
||||
mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp
|
||||
}
|
||||
|
||||
build() {
|
||||
cmake -S ${LLAMACPP_DIR} -B ${BUILD_DIR} ${CMAKE_DEFS}
|
||||
cmake --build ${BUILD_DIR} ${CMAKE_TARGETS} -j8
|
||||
mkdir -p ${BUILD_DIR}/lib/
|
||||
g++ -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.${LIB_EXT} \
|
||||
${GCC_ARCH} \
|
||||
${WHOLE_ARCHIVE} ${BUILD_DIR}/examples/server/libext_server.a ${NO_WHOLE_ARCHIVE} \
|
||||
${BUILD_DIR}/common/libcommon.a \
|
||||
${BUILD_DIR}/libllama.a \
|
||||
-Wl,-rpath,\$ORIGIN \
|
||||
-lpthread -ldl -lm \
|
||||
${EXTRA_LIBS}
|
||||
}
|
||||
|
||||
compress_libs() {
|
||||
echo "Compressing payloads to reduce overall binary size..."
|
||||
pids=""
|
||||
rm -rf ${BUILD_DIR}/lib/*.${LIB_EXT}*.gz
|
||||
for lib in ${BUILD_DIR}/lib/*.${LIB_EXT}* ; do
|
||||
gzip --best -f ${lib} &
|
||||
pids+=" $!"
|
||||
done
|
||||
echo
|
||||
for pid in ${pids}; do
|
||||
wait $pid
|
||||
done
|
||||
echo "Finished compression"
|
||||
}
|
||||
|
||||
# Keep the local tree clean after we're done with the build
|
||||
cleanup() {
|
||||
(cd ${LLAMACPP_DIR}/examples/server/ && git checkout CMakeLists.txt server.cpp)
|
||||
|
||||
if [ -n "$(ls -A ../patches/*.diff)" ]; then
|
||||
for patch in ../patches/*.diff; do
|
||||
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
|
||||
(cd ${LLAMACPP_DIR}; git checkout ${file})
|
||||
done
|
||||
done
|
||||
fi
|
||||
}
|
77 llm/generate/gen_darwin.sh Executable file
@@ -0,0 +1,77 @@
|
||||
#!/bin/bash
|
||||
# This script is intended to run inside the go generate
|
||||
# working directory must be ./llm/generate/
|
||||
|
||||
# TODO - add hardening to detect missing tools (cmake, etc.)
|
||||
|
||||
set -ex
|
||||
set -o pipefail
|
||||
echo "Starting darwin generate script"
|
||||
source $(dirname $0)/gen_common.sh
|
||||
init_vars
|
||||
git_module_setup
|
||||
apply_patches
|
||||
|
||||
sign() {
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp --deep --options=runtime --sign "$APPLE_IDENTITY" --identifier ai.ollama.ollama $1
|
||||
fi
|
||||
}
|
||||
|
||||
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin"
|
||||
|
||||
case "${GOARCH}" in
|
||||
"amd64")
|
||||
COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_NATIVE=off"
|
||||
|
||||
#
|
||||
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
|
||||
#
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu"
|
||||
echo "Building LCD CPU"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu/lib/libext_server.dylib
|
||||
compress_libs
|
||||
|
||||
#
|
||||
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
|
||||
# Approximately 400% faster than LCD on same CPU
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx"
|
||||
echo "Building AVX CPU"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx/lib/libext_server.dylib
|
||||
compress_libs
|
||||
|
||||
#
|
||||
# ~2013 CPU Dynamic library
|
||||
# Approximately 10% faster than AVX on same CPU
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2"
|
||||
echo "Building AVX2 CPU"
|
||||
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2/lib/libext_server.dylib
|
||||
compress_libs
|
||||
;;
|
||||
"arm64")
|
||||
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/metal"
|
||||
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/metal/lib/libext_server.dylib
|
||||
compress_libs
|
||||
;;
|
||||
*)
|
||||
echo "GOARCH must be set"
|
||||
echo "this script is meant to be run from within go generate"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
cleanup
|
190 llm/generate/gen_linux.sh Executable file
@@ -0,0 +1,190 @@
|
||||
#!/bin/bash
|
||||
# This script is intended to run inside the go generate
|
||||
# working directory must be llm/generate/
|
||||
|
||||
# First we build one or more CPU based LLM libraries
|
||||
#
|
||||
# Then if we detect CUDA, we build a CUDA dynamic library, and carry the required
|
||||
# library dependencies
|
||||
#
|
||||
# Then if we detect ROCm, we build a dynamically loaded ROCm lib. The ROCM
|
||||
# libraries are quite large, and also dynamically load data files at runtime
|
||||
# which in turn are large, so we don't attempt to carry them as payload
|
||||
|
||||
set -ex
|
||||
set -o pipefail
|
||||
|
||||
# See https://llvm.org/docs/AMDGPUUsage.html#processors for reference
|
||||
amdGPUs() {
|
||||
if [ -n "${AMDGPU_TARGETS}" ]; then
|
||||
echo "${AMDGPU_TARGETS}"
|
||||
return
|
||||
fi
|
||||
GPU_LIST=(
|
||||
"gfx803"
|
||||
"gfx900"
|
||||
"gfx906:xnack-"
|
||||
"gfx908:xnack-"
|
||||
"gfx90a:xnack+"
|
||||
"gfx90a:xnack-"
|
||||
"gfx1010"
|
||||
"gfx1012"
|
||||
"gfx1030"
|
||||
"gfx1100"
|
||||
"gfx1101"
|
||||
"gfx1102"
|
||||
)
|
||||
(
|
||||
IFS=$';'
|
||||
echo "'${GPU_LIST[*]}'"
|
||||
)
|
||||
}
|
||||
|
||||
echo "Starting linux generate script"
|
||||
if [ -z "${CUDACXX}" ]; then
|
||||
if [ -x /usr/local/cuda/bin/nvcc ]; then
|
||||
export CUDACXX=/usr/local/cuda/bin/nvcc
|
||||
else
|
||||
# Try the default location in case it exists
|
||||
export CUDACXX=$(command -v nvcc)
|
||||
fi
|
||||
fi
|
||||
COMMON_CMAKE_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off"
|
||||
source $(dirname $0)/gen_common.sh
|
||||
init_vars
|
||||
git_module_setup
|
||||
apply_patches
|
||||
|
||||
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||
# Users building from source can tune the exact flags we pass to cmake for configuring
|
||||
# llama.cpp, and we'll build only 1 CPU variant in that case as the default.
|
||||
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
|
||||
echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\""
|
||||
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
|
||||
echo "Building custom CPU"
|
||||
build
|
||||
compress_libs
|
||||
else
|
||||
# Darwin Rosetta x86 emulation does NOT support AVX, AVX2, AVX512
|
||||
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
|
||||
# -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
|
||||
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
|
||||
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
|
||||
# Note: the following seem to yield slower results than AVX2 - ymmv
|
||||
# -DLLAMA_AVX512 -- 2017 Intel Skylake and High End DeskTop (HEDT)
|
||||
# -DLLAMA_AVX512_VBMI -- 2018 Intel Cannon Lake
|
||||
# -DLLAMA_AVX512_VNNI -- 2021 Intel Alder Lake
|
||||
|
||||
COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off"
|
||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then
|
||||
#
|
||||
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
|
||||
#
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
|
||||
echo "Building LCD CPU"
|
||||
build
|
||||
compress_libs
|
||||
fi
|
||||
|
||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu_avx" ]; then
|
||||
#
|
||||
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
|
||||
# Approximately 400% faster than LCD on same CPU
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx"
|
||||
echo "Building AVX CPU"
|
||||
build
|
||||
compress_libs
|
||||
fi
|
||||
|
||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu_avx2" ]; then
|
||||
#
|
||||
# ~2013 CPU Dynamic library
|
||||
# Approximately 10% faster than AVX on same CPU
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx2"
|
||||
echo "Building AVX2 CPU"
|
||||
build
|
||||
compress_libs
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "Skipping CPU generation step as requested"
|
||||
fi
|
||||
|
||||
# If needed, look for the default CUDA toolkit location
|
||||
if [ -z "${CUDA_LIB_DIR}" ] && [ -d /usr/local/cuda/lib64 ]; then
|
||||
CUDA_LIB_DIR=/usr/local/cuda/lib64
|
||||
fi
|
||||
|
||||
# If needed, look for CUDA on Arch Linux
|
||||
if [ -z "${CUDA_LIB_DIR}" ] && [ -d /opt/cuda/targets/x86_64-linux/lib ]; then
|
||||
CUDA_LIB_DIR=/opt/cuda/targets/x86_64-linux/lib
|
||||
fi
|
||||
|
||||
if [ -d "${CUDA_LIB_DIR}" ]; then
|
||||
echo "CUDA libraries detected - building dynamic CUDA library"
|
||||
init_vars
|
||||
CUDA_MAJOR=$(ls "${CUDA_LIB_DIR}"/libcudart.so.* | head -1 | cut -f3 -d. || true)
|
||||
if [ -n "${CUDA_MAJOR}" ]; then
|
||||
CUDA_VARIANT=_v${CUDA_MAJOR}
|
||||
fi
|
||||
CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
||||
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
|
||||
build
|
||||
|
||||
# Carry the CUDA libs as payloads to help reduce dependency burden on users
|
||||
#
|
||||
# TODO - in the future we may shift to packaging these separately and conditionally
|
||||
# downloading them in the install script.
|
||||
DEPS="$(ldd ${BUILD_DIR}/lib/libext_server.so )"
|
||||
for lib in libcudart.so libcublas.so libcublasLt.so ; do
|
||||
DEP=$(echo "${DEPS}" | grep ${lib} | cut -f1 -d' ' | xargs || true)
|
||||
if [ -n "${DEP}" -a -e "${CUDA_LIB_DIR}/${DEP}" ]; then
|
||||
cp "${CUDA_LIB_DIR}/${DEP}" "${BUILD_DIR}/lib/"
|
||||
elif [ -e "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" ]; then
|
||||
cp "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" "${BUILD_DIR}/lib/"
|
||||
else
|
||||
cp -d "${CUDA_LIB_DIR}/${lib}*" "${BUILD_DIR}/lib/"
|
||||
fi
|
||||
done
|
||||
compress_libs
|
||||
|
||||
fi
|
||||
|
||||
if [ -z "${ROCM_PATH}" ]; then
|
||||
# Try the default location in case it exists
|
||||
ROCM_PATH=/opt/rocm
|
||||
fi
|
||||
|
||||
if [ -z "${CLBlast_DIR}" ]; then
|
||||
# Try the default location in case it exists
|
||||
if [ -d /usr/lib/cmake/CLBlast ]; then
|
||||
export CLBlast_DIR=/usr/lib/cmake/CLBlast
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -d "${ROCM_PATH}" ]; then
|
||||
echo "ROCm libraries detected - building dynamic ROCm library"
|
||||
if [ -f ${ROCM_PATH}/lib/librocm_smi64.so.? ]; then
|
||||
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocm_smi64.so.? | cut -f3 -d. || true)
|
||||
fi
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/rocm${ROCM_VARIANT}"
|
||||
EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,${ROCM_PATH}/lib,-rpath,/opt/amdgpu/lib/x86_64-linux-gnu/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
|
||||
build
|
||||
|
||||
# Note: the ROCM libs and runtime library files are too large to embed, so we depend on
|
||||
# them being present at runtime on the host
|
||||
compress_libs
|
||||
fi
|
||||
|
||||
cleanup
|
175 llm/generate/gen_windows.ps1 Normal file
@@ -0,0 +1,175 @@
|
||||
#!powershell
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
function init_vars {
|
||||
$script:llamacppDir = "../llama.cpp"
|
||||
$script:cmakeDefs = @("-DBUILD_SHARED_LIBS=on", "-DLLAMA_NATIVE=off", "-A","x64")
|
||||
$script:cmakeTargets = @("ext_server")
|
||||
$script:ARCH = "amd64" # arm not yet supported.
|
||||
if ($env:CGO_CFLAGS -contains "-g") {
|
||||
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on")
|
||||
$script:config = "RelWithDebInfo"
|
||||
} else {
|
||||
$script:cmakeDefs += @("-DLLAMA_SERVER_VERBOSE=off")
|
||||
$script:config = "Release"
|
||||
}
|
||||
# Try to find the CUDA dir
|
||||
if ($env:CUDA_LIB_DIR -eq $null) {
|
||||
$d=(get-command -ea 'silentlycontinue' nvcc).path
|
||||
if ($d -ne $null) {
|
||||
$script:CUDA_LIB_DIR=($d| split-path -parent)
|
||||
}
|
||||
} else {
|
||||
$script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR
|
||||
}
|
||||
$script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
|
||||
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
|
||||
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
|
||||
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
|
||||
} else {
|
||||
$script:CMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES
|
||||
}
|
||||
}
|
||||
|
||||
function git_module_setup {
|
||||
# TODO add flags to skip the init/patch logic to make it easier to mod llama.cpp code in-repo
|
||||
& git submodule init
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& git submodule update --force "${script:llamacppDir}"
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
|
||||
function apply_patches {
|
||||
# Wire up our CMakefile
|
||||
if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 'ollama')) {
|
||||
Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama'
|
||||
}
|
||||
|
||||
# Apply temporary patches until fix is upstream
|
||||
$patches = Get-ChildItem "../patches/*.diff"
|
||||
foreach ($patch in $patches) {
|
||||
# Extract file paths from the patch file
|
||||
$filePaths = Get-Content $patch.FullName | Where-Object { $_ -match '^\+\+\+ ' } | ForEach-Object {
|
||||
$parts = $_ -split ' '
|
||||
($parts[1] -split '/', 2)[1]
|
||||
}
|
||||
|
||||
# Checkout each file
|
||||
foreach ($file in $filePaths) {
|
||||
Set-Location -Path ${script:llamacppDir}
|
||||
git checkout $file
|
||||
}
|
||||
}
|
||||
|
||||
# Apply each patch
|
||||
foreach ($patch in $patches) {
|
||||
Set-Location -Path ${script:llamacppDir}
|
||||
git apply $patch.FullName
|
||||
}
|
||||
|
||||
# Avoid duplicate main symbols when we link into the cgo binary
|
||||
$content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp"
|
||||
$content = $content -replace 'int main\(', 'int __main('
|
||||
Set-Content -Path "${script:llamacppDir}/examples/server/server.cpp" -Value $content
|
||||
}
|
||||
|
||||
function build {
|
||||
write-host "generating config with: cmake -S ${script:llamacppDir} -B $script:buildDir $script:cmakeDefs"
|
||||
& cmake --version
|
||||
& cmake -S "${script:llamacppDir}" -B $script:buildDir $script:cmakeDefs
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
write-host "building with: cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })"
|
||||
& cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
|
||||
function install {
|
||||
rm -ea 0 -recurse -force -path "${script:buildDir}/lib"
|
||||
md "${script:buildDir}/lib" -ea 0 > $null
|
||||
cp "${script:buildDir}/bin/${script:config}/ext_server.dll" "${script:buildDir}/lib"
|
||||
cp "${script:buildDir}/bin/${script:config}/llama.dll" "${script:buildDir}/lib"
|
||||
|
||||
# Display the dll dependencies in the build log
|
||||
if ($script:DUMPBIN -ne $null) {
|
||||
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/${script:config}/ext_server.dll" | select-string ".dll"
|
||||
}
|
||||
}
|
||||
|
||||
function compress_libs {
|
||||
if ($script:GZIP -eq $null) {
|
||||
write-host "gzip not installed, not compressing files"
|
||||
return
|
||||
}
|
||||
write-host "Compressing dlls..."
|
||||
$libs = dir "${script:buildDir}/lib/*.dll"
|
||||
foreach ($file in $libs) {
|
||||
& "$script:GZIP" --best -f $file
|
||||
}
|
||||
}
|
||||
|
||||
function cleanup {
|
||||
Set-Location "${script:llamacppDir}/examples/server"
|
||||
git checkout CMakeLists.txt server.cpp
|
||||
}
|
||||
|
||||
init_vars
|
||||
git_module_setup
|
||||
apply_patches
|
||||
|
||||
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
|
||||
# -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
|
||||
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
|
||||
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
|
||||
|
||||
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on", "-DLLAMA_NATIVE=off")
|
||||
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu"
|
||||
write-host "Building LCD CPU"
|
||||
build
|
||||
install
|
||||
compress_libs
|
||||
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu_avx"
|
||||
write-host "Building AVX CPU"
|
||||
build
|
||||
install
|
||||
compress_libs
|
||||
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu_avx2"
|
||||
write-host "Building AVX2 CPU"
|
||||
build
|
||||
install
|
||||
compress_libs
|
||||
|
||||
if ($null -ne $script:CUDA_LIB_DIR) {
|
||||
# Then build cuda as a dynamically loaded library
|
||||
$nvcc = (get-command -ea 'silentlycontinue' nvcc)
|
||||
if ($null -ne $nvcc) {
|
||||
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
|
||||
}
|
||||
if ($null -ne $script:CUDA_VERSION) {
|
||||
$script:CUDA_VARIANT="_"+$script:CUDA_VERSION
|
||||
}
|
||||
init_vars
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
|
||||
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
|
||||
build
|
||||
install
|
||||
cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib"
|
||||
cp "${script:CUDA_LIB_DIR}/cublas64_*.dll" "${script:buildDir}/lib"
|
||||
cp "${script:CUDA_LIB_DIR}/cublasLt64_*.dll" "${script:buildDir}/lib"
|
||||
compress_libs
|
||||
}
|
||||
# TODO - actually implement ROCm support on windows
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/rocm"
|
||||
|
||||
rm -ea 0 -recurse -force -path "${script:buildDir}/lib"
|
||||
md "${script:buildDir}/lib" -ea 0 > $null
|
||||
echo $null >> "${script:buildDir}/lib/.generated"
|
||||
|
||||
cleanup
|
||||
write-host "`ngo generate completed"
|
@@ -1,3 +1,3 @@
|
||||
package llm
|
||||
package generate
|
||||
|
||||
//go:generate sh ./gen_darwin.sh
|
@@ -1,3 +1,3 @@
|
||||
package llm
|
||||
package generate
|
||||
|
||||
//go:generate bash ./gen_linux.sh
|
@@ -1,3 +1,3 @@
|
||||
package llm
|
||||
package generate
|
||||
|
||||
//go:generate powershell -ExecutionPolicy Bypass -File ./gen_windows.ps1
|
13 llm/ggml.go
@@ -78,7 +78,12 @@ type model interface {
|
||||
ModelFamily() string
|
||||
ModelType() string
|
||||
FileType() string
|
||||
NumLayers() int64
|
||||
NumLayers() uint32
|
||||
NumGQA() uint32
|
||||
NumEmbed() uint32
|
||||
NumHead() uint32
|
||||
NumHeadKv() uint32
|
||||
NumCtx() uint32
|
||||
}
|
||||
|
||||
type container interface {
|
||||
@@ -94,9 +99,9 @@ func (c *containerLORA) Name() string {
|
||||
return "ggla"
|
||||
}
|
||||
|
||||
func (c *containerLORA) Decode(ro *readSeekOffset) (model, error) {
|
||||
func (c *containerLORA) Decode(rso *readSeekOffset) (model, error) {
|
||||
var version uint32
|
||||
binary.Read(ro, binary.LittleEndian, &version)
|
||||
binary.Read(rso, binary.LittleEndian, &version)
|
||||
|
||||
switch version {
|
||||
case 1:
|
||||
@@ -107,7 +112,7 @@ func (c *containerLORA) Decode(ro *readSeekOffset) (model, error) {
|
||||
c.version = version
|
||||
|
||||
// remaining file contents aren't decoded
|
||||
ro.Seek(0, io.SeekEnd)
|
||||
rso.Seek(0, io.SeekEnd)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
165 llm/gguf.go
@@ -69,12 +69,65 @@ type tensor struct {
|
||||
name string
|
||||
kind uint32
|
||||
offset uint64
|
||||
size uint64
|
||||
|
||||
// shape is the number of elements in each dimension
|
||||
shape [4]uint64
|
||||
}
|
||||
|
||||
func (t tensor) blockSize() uint64 {
|
||||
switch {
|
||||
case t.kind < 2:
|
||||
return 1
|
||||
case t.kind < 10:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
func (t tensor) typeSize() uint64 {
|
||||
blockSize := t.blockSize()
|
||||
|
||||
switch t.kind {
|
||||
case 0: // FP32
|
||||
return 4
|
||||
case 1: // FP16
|
||||
return 2
|
||||
case 2: // Q4_0
|
||||
return 2 + blockSize/2
|
||||
case 3: // Q4_1
|
||||
return 2 + 2 + blockSize/2
|
||||
case 6: // Q5_0
|
||||
return 2 + 4 + blockSize/2
|
||||
case 7: // Q5_1
|
||||
return 2 + 2 + 4 + blockSize/2
|
||||
case 8: // Q8_0
|
||||
return 2 + blockSize
|
||||
case 9: // Q8_1
|
||||
return 4 + 4 + blockSize
|
||||
case 10: // Q2_K
|
||||
return blockSize/16 + blockSize/4 + 2 + 2
|
||||
case 11: // Q3_K
|
||||
return blockSize/8 + blockSize/4 + 12 + 2
|
||||
case 12: // Q4_K
|
||||
return 2 + 2 + 12 + blockSize/2
|
||||
case 13: // Q5_K
|
||||
return 2 + 2 + 12 + blockSize/8 + blockSize/2
|
||||
case 14: // Q6_K
|
||||
return blockSize/2 + blockSize/4 + blockSize/16 + 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (t tensor) parameters() uint64 {
|
||||
return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3]
|
||||
}
|
||||
|
||||
func (t tensor) size() uint64 {
|
||||
return t.parameters() * t.typeSize() / t.blockSize()
|
||||
}
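To make the arithmetic concrete, assuming this runs inside the same package: a Q4_0 tensor (kind 2) uses 32-element blocks of 18 bytes each (a 2-byte scale plus 32 packed 4-bit weights), so a hypothetical 4096x4096 weight occupies 9,437,184 bytes.

```go
// Sketch: byte size of a hypothetical 4096x4096 Q4_0 tensor, using the
// blockSize/typeSize/size helpers defined above.
t := tensor{kind: 2, shape: [4]uint64{4096, 4096, 1, 1}}
fmt.Println(t.blockSize()) // 32
fmt.Println(t.typeSize())  // 18 = 2 (scale) + 32/2 (packed nibbles)
fmt.Println(t.size())      // 9437184 = 4096*4096 * 18 / 32
```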
|
||||
|
||||
type ggufModel struct {
|
||||
*containerGGUF
|
||||
|
||||
@@ -201,61 +254,15 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
|
||||
shape[i] = llm.readU64(rso)
|
||||
}
|
||||
|
||||
kind := llm.readU32(rso)
|
||||
offset := llm.readU64(rso)
|
||||
|
||||
var blockSize uint64
|
||||
switch {
|
||||
case kind < 2:
|
||||
blockSize = 1
|
||||
case kind < 10:
|
||||
blockSize = 32
|
||||
default:
|
||||
blockSize = 256
|
||||
}
|
||||
|
||||
var typeSize uint64
|
||||
switch kind {
|
||||
case 0: // FP32
|
||||
typeSize = 4
|
||||
case 1: // FP16
|
||||
typeSize = 2
|
||||
case 2: // Q4_0
|
||||
typeSize = 2 + blockSize/2
|
||||
case 3: // Q4_1
|
||||
typeSize = 2 + 2 + blockSize/2
|
||||
case 6: // Q5_0
|
||||
typeSize = 2 + 4 + blockSize/2
|
||||
case 7: // Q5_1
|
||||
typeSize = 2 + 2 + 4 + blockSize/2
|
||||
case 8: // Q8_0
|
||||
typeSize = 2 + blockSize
|
||||
case 9: // Q8_1
|
||||
typeSize = 4 + 4 + blockSize
|
||||
case 10: // Q2_K
|
||||
typeSize = blockSize/16 + blockSize/4 + 2 + 2
|
||||
case 11: // Q3_K
|
||||
typeSize = blockSize/8 + blockSize/4 + 12 + 2
|
||||
case 12: // Q4_K
|
||||
typeSize = 2 + 2 + 12 + blockSize/2
|
||||
case 13: // Q5_K
|
||||
typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
|
||||
case 14: // Q6_K
|
||||
typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
|
||||
}
|
||||
|
||||
parameters := shape[0] * shape[1] * shape[2] * shape[3]
|
||||
size := parameters * typeSize / blockSize
|
||||
|
||||
llm.tensors = append(llm.tensors, tensor{
|
||||
tensor := tensor{
|
||||
name: name,
|
||||
kind: kind,
|
||||
offset: offset,
|
||||
size: size,
|
||||
kind: llm.readU32(rso),
|
||||
offset: llm.readU64(rso),
|
||||
shape: shape,
|
||||
})
|
||||
}
|
||||
|
||||
llm.parameters += parameters
|
||||
llm.tensors = append(llm.tensors, tensor)
|
||||
llm.parameters += tensor.parameters()
|
||||
}
|
||||
|
||||
alignment, ok := llm.kv["general.alignment"].(uint32)
|
||||
@@ -265,21 +272,65 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
|
||||
|
||||
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
|
||||
for _, tensor := range llm.tensors {
|
||||
padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
|
||||
padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
|
||||
rso.Seek(padded, io.SeekCurrent)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
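The padding expression in `Decode` rounds each tensor's size up to the next multiple of `general.alignment` (a power of two) with the usual bit trick; a small self-contained sketch:

```go
package main

import "fmt"

// padTo rounds n up to the next multiple of align (align must be a power of
// two), matching the padded expression used in Decode above.
func padTo(n, align int64) int64 {
	return (n + align - 1) & ^(align - 1)
}

func main() {
	fmt.Println(padTo(9437184, 32)) // 9437184 (already aligned)
	fmt.Println(padTo(9437190, 32)) // 9437216
}
```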
|
||||
|
||||
func (llm *ggufModel) NumLayers() int64 {
|
||||
func (llm *ggufModel) NumLayers() uint32 {
|
||||
value, exists := llm.kv[fmt.Sprintf("%s.block_count", llm.ModelFamily())]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
|
||||
v := value.(uint32)
|
||||
return int64(v)
|
||||
return value.(uint32)
|
||||
}
|
||||
|
||||
func (llm *ggufModel) NumHead() uint32 {
|
||||
value, exists := llm.kv[fmt.Sprintf("%s.attention.head_count", llm.ModelFamily())]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
|
||||
return value.(uint32)
|
||||
}
|
||||
|
||||
func (llm *ggufModel) NumEmbed() uint32 {
|
||||
value, exists := llm.kv[fmt.Sprintf("%s.embedding_length", llm.ModelFamily())]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
|
||||
return value.(uint32)
|
||||
}
|
||||
|
||||
func (llm *ggufModel) NumHeadKv() uint32 {
|
||||
value, exists := llm.kv[fmt.Sprintf("%s.attention.head_count_kv", llm.ModelFamily())]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
|
||||
return value.(uint32)
|
||||
}
|
||||
|
||||
func (llm *ggufModel) NumCtx() uint32 {
|
||||
value, exists := llm.kv[fmt.Sprintf("%s.context_length", llm.ModelFamily())]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
|
||||
return value.(uint32)
|
||||
}
|
||||
|
||||
func (llm *ggufModel) NumGQA() uint32 {
|
||||
numHeadKv := llm.NumHeadKv()
|
||||
if numHeadKv == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return llm.NumHead() / numHeadKv
|
||||
}
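As a worked example of the grouped-query-attention factor above, with hypothetical head counts: `head_count = 32` and `head_count_kv = 8` give a GQA factor of 4.

```go
// Sketch: NumGQA is simply query heads per KV head.
numHead := uint32(32)  // e.g. llama.attention.head_count
numHeadKv := uint32(8) // e.g. llama.attention.head_count_kv
fmt.Println(numHead / numHeadKv) // 4
```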
|
||||
|
||||
func (llm ggufModel) readU8(r io.Reader) uint8 {
|
||||
|
1 llm/llama.cpp Submodule
Submodule llm/llama.cpp added at d2f650cb5b
@@ -1,29 +0,0 @@
|
||||
# Ollama specific CMakefile to include in llama.cpp/examples/server
|
||||
|
||||
set(TARGET ext_server)
|
||||
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
|
||||
add_library(${TARGET} STATIC ../../../ext_server.cpp)
|
||||
target_include_directories(${TARGET} PRIVATE ../../common)
|
||||
target_include_directories(${TARGET} PRIVATE ../..)
|
||||
target_include_directories(${TARGET} PRIVATE ../../..)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_SERVER_LIBRARY=1)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_definitions(${TARGET} PRIVATE
|
||||
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
|
||||
)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(ext_server PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(ext_server PRIVATE LLAMA_SHARED LLAMA_BUILD)
|
||||
add_library(ext_server_shared SHARED $<TARGET_OBJECTS:ext_server>)
|
||||
target_link_libraries(ext_server_shared PRIVATE ggml llama llava common ${CMAKE_THREAD_LIBS_INIT})
|
||||
install(TARGETS ext_server_shared LIBRARY)
|
||||
endif()
|
||||
|
||||
if (CUDAToolkit_FOUND)
|
||||
target_include_directories(${TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
if (WIN32)
|
||||
target_link_libraries(ext_server_shared PRIVATE nvml)
|
||||
endif()
|
||||
endif()
|
@@ -1,53 +0,0 @@
# common logic across linux and darwin

init_vars() {
    LLAMACPP_DIR=gguf
    PATCHES="0001-Expose-callable-API-for-server.patch"
    CMAKE_DEFS=""
    CMAKE_TARGETS="--target ggml --target ggml_static --target llama --target build_info --target common --target ext_server --target llava_static"
    if echo "${CGO_CFLAGS}" | grep -- '-g' >/dev/null; then
        CMAKE_DEFS="-DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_VERBOSE_MAKEFILE=on -DLLAMA_GPROF=on -DLLAMA_SERVER_VERBOSE=on"
    else
        # TODO - add additional optimization flags...
        CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off"
    fi
}

git_module_setup() {
    if [ -n "${OLLAMA_SKIP_PATCHING}" ]; then
        echo "Skipping submodule initialization"
        return
    fi
    git submodule init
    git submodule update --force gguf

}

apply_patches() {
    # Wire up our CMakefile
    if ! grep ollama gguf/examples/server/CMakeLists.txt; then
        echo 'include (../../../CMakeLists.txt) # ollama' >>gguf/examples/server/CMakeLists.txt
    fi
    # Avoid duplicate main symbols when we link into the cgo binary
    sed -e 's/int main(/int __main(/g' <./gguf/examples/server/server.cpp >./gguf/examples/server/server.cpp.tmp &&
        mv ./gguf/examples/server/server.cpp.tmp ./gguf/examples/server/server.cpp
}

build() {
    cmake -S ${LLAMACPP_DIR} -B ${BUILD_DIR} ${CMAKE_DEFS}
    cmake --build ${BUILD_DIR} ${CMAKE_TARGETS} -j8
}

install() {
    rm -rf ${BUILD_DIR}/lib
    mkdir -p ${BUILD_DIR}/lib
    cp ${BUILD_DIR}/examples/server/libext_server.a ${BUILD_DIR}/lib
    cp ${BUILD_DIR}/common/libcommon.a ${BUILD_DIR}/lib
    cp ${BUILD_DIR}/libllama.a ${BUILD_DIR}/lib
    cp ${BUILD_DIR}/libggml_static.a ${BUILD_DIR}/lib
}

# Keep the local tree clean after we're done with the build
cleanup() {
    (cd gguf/examples/server/ && git checkout CMakeLists.txt server.cpp)
}
@@ -1,32 +0,0 @@
#!/bin/bash
# This script is intended to run inside the go generate
# working directory must be ../llm/llama.cpp

# TODO - add hardening to detect missing tools (cmake, etc.)

set -ex
set -o pipefail
echo "Starting darwin generate script"
source $(dirname $0)/gen_common.sh
init_vars
CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on ${CMAKE_DEFS}"
BUILD_DIR="gguf/build/darwin/metal"
case "${GOARCH}" in
"amd64")
    CMAKE_DEFS="-DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
    ;;
"arm64")
    CMAKE_DEFS="-DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 ${CMAKE_DEFS}"
    ;;
*)
    echo "GOARCH must be set"
    echo "this script is meant to be run from within go generate"
    exit 1
    ;;
esac

git_module_setup
apply_patches
build
install
cleanup
@@ -1,116 +0,0 @@
#!/bin/bash
# This script is intended to run inside the go generate
# working directory must be llm/llama.cpp

# First we build our default built-in library which will be linked into the CGO
# binary as a normal dependency. This default build is CPU based.
#
# Then we build a CUDA dynamic library (although statically linked with the CUDA
# library dependencies for maximum portability)
#
# Then if we detect ROCm, we build a dynamically loaded ROCm lib. ROCm is particularly
# important to be a dynamic lib even if it's the only GPU library detected because
# we can't redistribute the object files but must rely on dynamic libraries at
# runtime, which could lead the server not to start if not present.

set -ex
set -o pipefail

# See https://llvm.org/docs/AMDGPUUsage.html#processors for reference
amdGPUs() {
    GPU_LIST=(
        "gfx803"
        "gfx900"
        "gfx906:xnack-"
        "gfx908:xnack-"
        "gfx90a:xnack+"
        "gfx90a:xnack-"
        "gfx1010"
        "gfx1012"
        "gfx1030"
        "gfx1100"
        "gfx1101"
        "gfx1102"
    )
    (
        IFS=$';'
        echo "'${GPU_LIST[*]}'"
    )
}

echo "Starting linux generate script"
if [ -z "${CUDACXX}" -a -x /usr/local/cuda/bin/nvcc ]; then
    export CUDACXX=/usr/local/cuda/bin/nvcc
fi
COMMON_CMAKE_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off"
source $(dirname $0)/gen_common.sh
init_vars
git_module_setup
apply_patches

#
# CPU first for the default library
#
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
BUILD_DIR="gguf/build/linux/cpu"

build
install

# Placeholder to keep go embed happy until we start building dynamic CPU lib variants
touch ${BUILD_DIR}/lib/dummy.so

if [ -d /usr/local/cuda/lib64/ ]; then
    echo "CUDA libraries detected - building dynamic CUDA library"
    init_vars
    CMAKE_DEFS="-DLLAMA_CUBLAS=on ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
    BUILD_DIR="gguf/build/linux/cuda"
    CUDA_LIB_DIR=/usr/local/cuda/lib64
    build
    install
    gcc -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.so \
        -Wl,--whole-archive \
        ${BUILD_DIR}/lib/libext_server.a \
        ${BUILD_DIR}/lib/libcommon.a \
        ${BUILD_DIR}/lib/libllama.a \
        -Wl,--no-whole-archive \
        ${CUDA_LIB_DIR}/libcudart_static.a \
        ${CUDA_LIB_DIR}/libcublas_static.a \
        ${CUDA_LIB_DIR}/libcublasLt_static.a \
        ${CUDA_LIB_DIR}/libcudadevrt.a \
        ${CUDA_LIB_DIR}/libculibos.a \
        -lrt -lpthread -ldl -lstdc++ -lm
fi

if [ -z "${ROCM_PATH}" ]; then
    # Try the default location in case it exists
    ROCM_PATH=/opt/rocm
fi

if [ -z "${CLBlast_DIR}" ]; then
    # Try the default location in case it exists
    if [ -d /usr/lib/cmake/CLBlast ]; then
        export CLBlast_DIR=/usr/lib/cmake/CLBlast
    fi
fi

if [ -d "${ROCM_PATH}" ]; then
    echo "ROCm libraries detected - building dynamic ROCm library"
    init_vars
    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
    BUILD_DIR="gguf/build/linux/rocm"
    build
    install
    gcc -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.so \
        -Wl,--whole-archive \
        ${BUILD_DIR}/lib/libext_server.a \
        ${BUILD_DIR}/lib/libcommon.a \
        ${BUILD_DIR}/lib/libllama.a \
        -Wl,--no-whole-archive \
        -lrt -lpthread -ldl -lstdc++ -lm \
        -L/opt/rocm/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ \
        -Wl,-rpath,/opt/rocm/lib,-rpath,/opt/amdgpu/lib/x86_64-linux-gnu/ \
        -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu
fi

cleanup
@@ -1,87 +0,0 @@
|
||||
#!powershell
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
function init_vars {
|
||||
$script:patches = @("0001-Expose-callable-API-for-server.patch")
|
||||
$script:cmakeDefs = @("-DBUILD_SHARED_LIBS=on", "-DLLAMA_NATIVE=off", "-DLLAMA_F16C=off", "-DLLAMA_FMA=off", "-DLLAMA_AVX512=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX=on", "-A","x64")
|
||||
$script:cmakeTargets = @("ggml", "ggml_static", "llama", "build_info", "common", "ext_server_shared", "llava_static")
|
||||
if ($env:CGO_CFLAGS -contains "-g") {
|
||||
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on")
|
||||
$script:config = "RelWithDebInfo"
|
||||
} else {
|
||||
$script:cmakeDefs += @("-DLLAMA_SERVER_VERBOSE=off")
|
||||
$script:config = "Release"
|
||||
}
|
||||
}
|
||||
|
||||
function git_module_setup {
|
||||
# TODO add flags to skip the init/patch logic to make it easier to mod llama.cpp code in-repo
|
||||
& git submodule init
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& git submodule update --force gguf
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
|
||||
function apply_patches {
|
||||
# Wire up our CMakefile
|
||||
if (!(Select-String -Path "gguf/examples/server/CMakeLists.txt" -Pattern 'ollama')) {
|
||||
Add-Content -Path "gguf/examples/server/CMakeLists.txt" -Value 'include (../../../CMakeLists.txt) # ollama'
|
||||
}
|
||||
# Avoid duplicate main symbols when we link into the cgo binary
|
||||
$content = Get-Content -Path "./gguf/examples/server/server.cpp"
|
||||
$content = $content -replace 'int main\(', 'int __main('
|
||||
Set-Content -Path "./gguf/examples/server/server.cpp" -Value $content
|
||||
}
|
||||
|
||||
function build {
|
||||
write-host "generating config with: cmake -S gguf -B $script:buildDir $script:cmakeDefs"
|
||||
& cmake --version
|
||||
& cmake -S gguf -B $script:buildDir $script:cmakeDefs
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
write-host "building with: cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })"
|
||||
& cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
|
||||
function install {
|
||||
rm -ea 0 -recurse -force -path "${script:buildDir}/lib"
|
||||
md "${script:buildDir}/lib" -ea 0 > $null
|
||||
cp "${script:buildDir}/bin/${script:config}/ext_server_shared.dll" "${script:buildDir}/lib"
|
||||
cp "${script:buildDir}/bin/${script:config}/llama.dll" "${script:buildDir}/lib"
|
||||
|
||||
# Display the dll dependencies in the build log
|
||||
dumpbin /dependents "${script:buildDir}/bin/${script:config}/ext_server_shared.dll" | select-string ".dll"
|
||||
}
|
||||
|
||||
function cleanup {
|
||||
Set-Location "gguf/examples/server"
|
||||
git checkout CMakeLists.txt server.cpp
|
||||
}
|
||||
|
||||
init_vars
|
||||
git_module_setup
|
||||
apply_patches
|
||||
|
||||
# first build CPU based
|
||||
$script:buildDir="gguf/build/windows/cpu"
|
||||
|
||||
build
|
||||
install
|
||||
|
||||
# Then build cuda as a dynamically loaded library
|
||||
init_vars
|
||||
$script:buildDir="gguf/build/windows/cuda"
|
||||
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON")
|
||||
build
|
||||
install
|
||||
|
||||
# TODO - actually implement ROCm support on windows
|
||||
$script:buildDir="gguf/build/windows/rocm"
|
||||
|
||||
rm -ea 0 -recurse -force -path "${script:buildDir}/lib"
|
||||
md "${script:buildDir}/lib" -ea 0 > $null
|
||||
echo $null >> "${script:buildDir}/lib/.generated"
|
||||
|
||||
cleanup
|
||||
write-host "`ngo generate completed"
|
Submodule llm/llama.cpp/gguf deleted from 328b83de23
llm/llama.go (110 lines changed)
@@ -1,18 +1,11 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/format"
|
||||
)
|
||||
|
||||
const jsonGrammar = `
|
||||
@@ -43,109 +36,12 @@ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
ws ::= ([ \t\n] ws)?
|
||||
`
|
||||
|
||||
type llamaModel struct {
|
||||
hyperparameters llamaHyperparameters
|
||||
}
|
||||
|
||||
func (llm *llamaModel) ModelFamily() string {
|
||||
return "llama"
|
||||
}
|
||||
|
||||
func llamaModelType(numLayer uint32) string {
|
||||
switch numLayer {
|
||||
case 26:
|
||||
return "3B"
|
||||
case 32:
|
||||
return "7B"
|
||||
case 40:
|
||||
return "13B"
|
||||
case 48:
|
||||
return "34B"
|
||||
case 60:
|
||||
return "30B"
|
||||
case 80:
|
||||
return "65B"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (llm *llamaModel) ModelType() string {
|
||||
return llamaModelType(llm.hyperparameters.NumLayer)
|
||||
}
|
||||
|
||||
func (llm *llamaModel) FileType() string {
|
||||
return fileType(llm.hyperparameters.FileType)
|
||||
}
|
||||
|
||||
func (llm *llamaModel) NumLayers() int64 {
|
||||
return int64(llm.hyperparameters.NumLayer)
|
||||
}
|
||||
|
||||
type llamaHyperparameters struct {
|
||||
// NumVocab is the size of the model's vocabulary.
|
||||
NumVocab uint32
|
||||
|
||||
// NumEmbd is the size of the model's embedding layer.
|
||||
NumEmbd uint32
|
||||
NumMult uint32
|
||||
NumHead uint32
|
||||
|
||||
// NumLayer is the number of layers in the model.
|
||||
NumLayer uint32
|
||||
NumRot uint32
|
||||
|
||||
// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
|
||||
FileType uint32
|
||||
}
|
||||
|
||||
type Running struct {
|
||||
Port int
|
||||
Cmd *exec.Cmd
|
||||
Cancel context.CancelFunc
|
||||
exitOnce sync.Once
|
||||
exitCh chan error // channel to receive the exit status of the subprocess
|
||||
*StatusWriter // captures error messages from the llama runner process
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
}
|
||||
|
||||
var (
|
||||
errNvidiaSMI = errors.New("warning: gpu support may not be enabled, check that you have installed GPU drivers: nvidia-smi command failed")
|
||||
errAvailableVRAM = errors.New("not enough VRAM available, falling back to CPU only")
|
||||
payloadMissing = fmt.Errorf("expected dynamic library payloads not included in this build of ollama")
|
||||
)
|
||||
|
||||
// StatusWriter is a writer that captures error messages from the llama runner process
|
||||
type StatusWriter struct {
|
||||
ErrCh chan error
|
||||
LastErrMsg string
|
||||
}
|
||||
|
||||
func NewStatusWriter() *StatusWriter {
|
||||
return &StatusWriter{
|
||||
ErrCh: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||
var errMsg string
|
||||
if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
|
||||
errMsg = string(bytes.TrimSpace(after))
|
||||
} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
|
||||
errMsg = string(bytes.TrimSpace(after))
|
||||
}
|
||||
|
||||
if errMsg != "" {
|
||||
w.LastErrMsg = errMsg
|
||||
w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg)
|
||||
}
|
||||
|
||||
return os.Stderr.Write(b)
|
||||
}
|
||||
var payloadMissing = fmt.Errorf("expected dynamic library payloads not included in this build of ollama")
|
||||
|
||||
type prediction struct {
|
||||
Content string `json:"content"`
|
||||
@@ -161,14 +57,12 @@ type prediction struct {
|
||||
}
|
||||
}
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
const maxRetries = 3
|
||||
const retryDelay = 1 * time.Second
|
||||
|
||||
type PredictOpts struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []api.ImageData
|
||||
Images []ImageData
|
||||
Options api.Options
|
||||
}
|
||||
|
||||
|
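The StatusWriter above mirrors runner output to stderr while capturing any line containing "error:" or "CUDA error". The following is a minimal sketch (not part of this change set) of how such a writer might be attached to a runner subprocess; the "llama-runner" command name and the runnerWithStatus helper are illustrative placeholders only.

package llm

import (
	"fmt"
	"os/exec"
)

// runnerWithStatus wires a StatusWriter to a subprocess: every stderr write passes
// through StatusWriter.Write, which mirrors it to os.Stderr and captures error lines.
// ErrCh is drained until the process exits so Write never blocks on the buffered channel.
func runnerWithStatus() error {
	sw := NewStatusWriter()
	cmd := exec.Command("llama-runner") // placeholder command name, for illustration only
	cmd.Stderr = sw

	if err := cmd.Start(); err != nil {
		return err
	}

	done := make(chan error, 1)
	go func() { done <- cmd.Wait() }()

	for {
		select {
		case runnerErr := <-sw.ErrCh:
			fmt.Println("captured:", runnerErr)
		case err := <-done:
			if err != nil && sw.LastErrMsg != "" {
				return fmt.Errorf("llama runner: %s", sw.LastErrMsg)
			}
			return err
		}
	}
}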
llm/llm.go (147 lines changed)
@@ -3,14 +3,11 @@ package llm
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/pbnjay/memory"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/format"
|
||||
"github.com/jmorganca/ollama/gpu"
|
||||
)
|
||||
|
||||
@@ -22,8 +19,6 @@ type LLM interface {
|
||||
Close()
|
||||
}
|
||||
|
||||
var AvailableShims = map[string]string{}
|
||||
|
||||
func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
|
||||
if _, err := os.Stat(model); err != nil {
|
||||
return nil, err
|
||||
@@ -40,40 +35,92 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
var requiredMemory int64
|
||||
var f16Multiplier int64 = 2
|
||||
|
||||
switch ggml.ModelType() {
|
||||
case "3B", "7B":
|
||||
requiredMemory = 8 * format.GigaByte
|
||||
case "13B":
|
||||
requiredMemory = 16 * format.GigaByte
|
||||
case "30B", "34B", "40B":
|
||||
requiredMemory = 32 * format.GigaByte
|
||||
case "47B":
|
||||
requiredMemory = 48 * format.GigaByte
|
||||
case "65B", "70B":
|
||||
requiredMemory = 64 * format.GigaByte
|
||||
case "180B":
|
||||
requiredMemory = 128 * format.GigaByte
|
||||
f16Multiplier = 4
|
||||
}
|
||||
|
||||
systemMemory := int64(memory.TotalMemory())
|
||||
|
||||
if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > systemMemory {
|
||||
return nil, fmt.Errorf("F16 model requires at least %s of memory", format.HumanBytes(requiredMemory))
|
||||
} else if requiredMemory > systemMemory {
|
||||
return nil, fmt.Errorf("model requires at least %s of memory", format.HumanBytes(requiredMemory))
|
||||
}
|
||||
if opts.NumCtx > int(ggml.NumCtx()) {
|
||||
slog.Warn(fmt.Sprintf("requested context length is greater than model's max context length (%d > %d), using %d instead", opts.NumCtx, ggml.NumCtx(), ggml.NumCtx()))
|
||||
opts.NumCtx = int(ggml.NumCtx())
|
||||
}
|
||||
|
||||
if opts.NumCtx < 4 {
|
||||
opts.NumCtx = 4
|
||||
}
|
||||
|
||||
vram, _ := gpu.CheckVRAM()
|
||||
size := ggml.Size
|
||||
|
||||
// fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value
|
||||
kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead())
|
||||
|
||||
// this amount is the overhead + tensors in memory
|
||||
// TODO: get this from the llama.cpp's graph calculations instead of
|
||||
// estimating it's 1/6 * kv_cache_size * num_gqa
|
||||
graph := int64(ggml.NumGQA()) * kv / 6
|
||||
|
||||
info := gpu.GetGPUInfo()
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
if opts.NumGPU == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if size+kv+graph > vram {
|
||||
slog.Info("not enough vram available, falling back to CPU only")
|
||||
info.Library = "cpu"
|
||||
info.Variant = gpu.GetCPUVariant()
|
||||
opts.NumGPU = 0
|
||||
break
|
||||
}
|
||||
|
||||
// TODO: implement layer splitting on macOS
|
||||
opts.NumGPU = 999
|
||||
default:
|
||||
if info.Library == "cpu" {
|
||||
slog.Info("GPU not available, falling back to CPU")
|
||||
opts.NumGPU = 0
|
||||
break
|
||||
}
|
||||
|
||||
// don't use GPU at all if no layers are loaded
|
||||
if opts.NumGPU == 0 {
|
||||
info.Library = "cpu"
|
||||
info.Variant = gpu.GetCPUVariant()
|
||||
break
|
||||
}
|
||||
|
||||
// user-defined GPU count
|
||||
if opts.NumGPU != -1 {
|
||||
break
|
||||
}
|
||||
|
||||
// the "main" GPU needs the most memory and determines the limit
|
||||
// of how many layers can be loaded. It needs to fit:
|
||||
// 1. the full compute graph allocation for all devices (graph)
|
||||
// 2. the proportional kv cache for all devices (kv * % layers)
|
||||
// 3. the proportional model (size * % layers / # devices)
|
||||
// This estimates the number of layers
|
||||
maxlayers := int64(ggml.NumLayers()) + 1
|
||||
devices := int64(info.DeviceCount)
|
||||
avg := vram / devices
|
||||
layers := maxlayers * (avg - graph) / (kv + size/devices)
|
||||
if layers > maxlayers {
|
||||
layers = maxlayers
|
||||
}
|
||||
|
||||
// 1 + 2 must fit on the main gpu
|
||||
min := graph + kv*layers/maxlayers
|
||||
if layers <= 0 || min > avg {
|
||||
slog.Info("not enough vram available, falling back to CPU only")
|
||||
info.Library = "cpu"
|
||||
info.Variant = gpu.GetCPUVariant()
|
||||
opts.NumGPU = 0
|
||||
break
|
||||
}
|
||||
|
||||
opts.NumGPU = int(layers)
|
||||
}
|
||||
|
||||
opts.NumGQA = 0
|
||||
opts.RopeFrequencyBase = 0.0
|
||||
opts.RopeFrequencyScale = 0.0
|
||||
gpuInfo := gpu.GetGPUInfo()
|
||||
return newLlmServer(gpuInfo.Library, model, adapters, projectors, ggml.NumLayers(), opts)
|
||||
return newLlmServer(info, model, adapters, projectors, opts)
|
||||
}
|
||||
|
||||
// Give any native cgo implementations an opportunity to initialize
|
||||
@@ -81,16 +128,30 @@ func Init(workdir string) error {
|
||||
return nativeInit(workdir)
|
||||
}
|
||||
|
||||
func newLlmServer(library, model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
|
||||
if _, libPresent := AvailableShims[library]; libPresent && library != "default" {
|
||||
srv, err := newDynamicShimExtServer(AvailableShims[library], model, adapters, projectors, numLayers, opts)
|
||||
func newLlmServer(gpuInfo gpu.GpuInfo, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
|
||||
dynLibs := getDynLibs(gpuInfo)
|
||||
|
||||
// Check to see if the user has requested a specific library instead of auto-detecting
|
||||
demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
|
||||
if demandLib != "" {
|
||||
libPath := availableDynLibs[demandLib]
|
||||
if libPath == "" {
|
||||
slog.Info(fmt.Sprintf("Invalid OLLAMA_LLM_LIBRARY %s - not found", demandLib))
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("Loading OLLAMA_LLM_LIBRARY=%s", demandLib))
|
||||
dynLibs = []string{libPath}
|
||||
}
|
||||
}
|
||||
|
||||
err2 := fmt.Errorf("unable to locate suitable llm library")
|
||||
for _, dynLib := range dynLibs {
|
||||
srv, err := newDynExtServer(dynLib, model, adapters, projectors, opts)
|
||||
if err == nil {
|
||||
return srv, nil
|
||||
}
|
||||
log.Printf("Failed to load dynamic library %s - falling back to CPU mode %s", library, err)
|
||||
// TODO - update some state to indicate we were unable to load the GPU library for future "info" ux
|
||||
slog.Warn(fmt.Sprintf("Failed to load dynamic library %s %s", dynLib, err))
|
||||
err2 = err
|
||||
}
|
||||
|
||||
return newDefaultExtServer(model, adapters, projectors, numLayers, opts)
|
||||
|
||||
return nil, err2
|
||||
}
|
||||
|
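To make the memory arithmetic above concrete, the sketch below (not part of the diff) plugs illustrative 7B-class hyperparameters into the same fp16 k/v formula: 2 bytes per value, times 2 for keys and values, times n_ctx * n_layer * n_embd * n_head_kv / n_head. The parameter values and the kvCacheEstimate helper are assumptions for illustration, not values read from a real model file.

package llm

import "fmt"

// kvCacheEstimate reproduces the estimate used above with assumed 7B-class
// parameters: 2048-token context, 32 layers, 4096-wide embeddings, and 32 heads
// with no grouped-query attention (n_head == n_head_kv).
func kvCacheEstimate() {
	var (
		numCtx    int64 = 2048
		numLayer  int64 = 32
		numEmbed  int64 = 4096
		numHead   int64 = 32
		numHeadKv int64 = 32
	)

	// 2 bytes per fp16 value * 2 (key and value) * ctx * layers * embd * head_kv / head
	kv := 2 * 2 * numCtx * numLayer * numEmbed * numHeadKv / numHead

	// graph overhead estimate from the code above, assuming NumGQA() == n_head/n_head_kv == 1
	graph := 1 * kv / 6

	fmt.Printf("kv cache ~= %d MiB, graph overhead ~= %d MiB\n", kv>>20, graph>>20)
	// prints: kv cache ~= 1024 MiB, graph overhead ~= 170 MiB
}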
llm/patches/01-cache.diff (new file, 30 lines)
@@ -0,0 +1,30 @@
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a48582ad..9fffffd8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1564,12 +1564,6 @@ struct llama_server_context
    LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
    }

-   LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
-
-   llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
-
-   slot.cache_tokens = prompt_tokens;
-
    if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
    {
        // we have to evaluate at least 1 token to generate logits.
@@ -1581,6 +1575,12 @@ struct llama_server_context
        }
    }

+   LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
+
+   llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
+
+   slot.cache_tokens = prompt_tokens;
+
    LOG_VERBOSE("prompt ingested", {
        {"n_past", slot.n_past},
        {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
llm/patches/02-shutdown.diff (new file, 90 lines)
@@ -0,0 +1,90 @@
|
||||
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
|
||||
index 11dd82c3..311495a8 100644
|
||||
--- a/examples/server/server.cpp
|
||||
+++ b/examples/server/server.cpp
|
||||
@@ -28,6 +28,7 @@
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <atomic>
|
||||
+#include <signal.h>
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
@@ -2394,6 +2395,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
|
||||
}
|
||||
}
|
||||
|
||||
+std::function<void(int)> shutdown_handler;
|
||||
+inline void signal_handler(int signal) { shutdown_handler(signal); }
|
||||
+
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
#if SERVER_VERBOSE != 1
|
||||
@@ -3014,8 +3018,14 @@ int main(int argc, char **argv)
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
- llama.queue_tasks.start_loop();
|
||||
|
||||
+ shutdown_handler = [&](int) {
|
||||
+ llama.queue_tasks.terminate();
|
||||
+ };
|
||||
+ signal(SIGTERM, signal_handler);
|
||||
+ signal(SIGINT, signal_handler);
|
||||
+ llama.queue_tasks.start_loop();
|
||||
+ svr.stop();
|
||||
t.join();
|
||||
|
||||
llama_backend_free();
|
||||
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
|
||||
index 70cce072..2acb1eab 100644
|
||||
--- a/examples/server/utils.hpp
|
||||
+++ b/examples/server/utils.hpp
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <unordered_map>
|
||||
+#include <atomic>
|
||||
|
||||
#include "json.hpp"
|
||||
|
||||
@@ -190,6 +191,7 @@ inline std::string format_chatml(std::vector<json> messages)
|
||||
struct llama_server_queue {
|
||||
int id = 0;
|
||||
std::mutex mutex_tasks;
|
||||
+ std::atomic<bool> running;
|
||||
// queues
|
||||
std::vector<task_server> queue_tasks;
|
||||
std::vector<task_server> queue_tasks_deferred;
|
||||
@@ -248,9 +250,15 @@ struct llama_server_queue {
|
||||
queue_tasks_deferred.clear();
|
||||
}
|
||||
|
||||
- // Start the main loop. This call is blocking
|
||||
- [[noreturn]]
|
||||
+ // end the start_loop routine
|
||||
+ void terminate() {
|
||||
+ running = false;
|
||||
+ condition_tasks.notify_all();
|
||||
+ }
|
||||
+
|
||||
+ // Start the main loop.
|
||||
void start_loop() {
|
||||
+ running = true;
|
||||
while (true) {
|
||||
// new task arrived
|
||||
LOG_VERBOSE("have new task", {});
|
||||
@@ -294,8 +302,12 @@ struct llama_server_queue {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (queue_tasks.empty()) {
|
||||
+ if (!running.load()) {
|
||||
+ LOG_VERBOSE("ending start_loop", {});
|
||||
+ return;
|
||||
+ }
|
||||
condition_tasks.wait(lock, [&]{
|
||||
- return !queue_tasks.empty();
|
||||
+ return (!queue_tasks.empty() || !running.load());
|
||||
});
|
||||
}
|
||||
}
|
llm/payload_common.go (new file, 283 lines)
@@ -0,0 +1,283 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/jmorganca/ollama/gpu"
|
||||
)
|
||||
|
||||
// Libraries names may contain an optional variant separated by '_'
|
||||
// For example, "rocm_v6" and "rocm_v5" or "cpu" and "cpu_avx2"
|
||||
// Any library without a variant is the lowest common denominator
|
||||
var availableDynLibs = map[string]string{}
|
||||
|
||||
const pathComponentCount = 7
|
||||
|
||||
// getDynLibs returns an ordered list of LLM libraries to try, starting with the best
|
||||
func getDynLibs(gpuInfo gpu.GpuInfo) []string {
|
||||
// Short circuit if we know we're using the default built-in (darwin only)
|
||||
if gpuInfo.Library == "default" {
|
||||
return []string{"default"}
|
||||
}
|
||||
// TODO - temporary until we have multiple CPU variations for Darwin
|
||||
// Short circuit on darwin with metal only
|
||||
if len(availableDynLibs) == 1 {
|
||||
if _, onlyMetal := availableDynLibs["metal"]; onlyMetal {
|
||||
return []string{availableDynLibs["metal"]}
|
||||
}
|
||||
}
|
||||
|
||||
exactMatch := ""
|
||||
dynLibs := []string{}
|
||||
altDynLibs := []string{}
|
||||
requested := gpuInfo.Library
|
||||
if gpuInfo.Variant != "" {
|
||||
requested += "_" + gpuInfo.Variant
|
||||
}
|
||||
// Try to find an exact match
|
||||
for cmp := range availableDynLibs {
|
||||
if requested == cmp {
|
||||
exactMatch = cmp
|
||||
dynLibs = []string{availableDynLibs[cmp]}
|
||||
break
|
||||
}
|
||||
}
|
||||
// Then for GPUs load alternates and sort the list for consistent load ordering
|
||||
if gpuInfo.Library != "cpu" {
|
||||
for cmp := range availableDynLibs {
|
||||
if gpuInfo.Library == strings.Split(cmp, "_")[0] && cmp != exactMatch {
|
||||
altDynLibs = append(altDynLibs, cmp)
|
||||
}
|
||||
}
|
||||
slices.Sort(altDynLibs)
|
||||
for _, altDynLib := range altDynLibs {
|
||||
dynLibs = append(dynLibs, availableDynLibs[altDynLib])
|
||||
}
|
||||
}
|
||||
|
||||
// Load up the best CPU variant if not primary requested
|
||||
if gpuInfo.Library != "cpu" {
|
||||
variant := gpu.GetCPUVariant()
|
||||
// If no variant, then we fall back to default
|
||||
// If we have a variant, try that if we find an exact match
|
||||
// Attempting to run the wrong CPU instructions will panic the
|
||||
// process
|
||||
if variant != "" {
|
||||
for cmp := range availableDynLibs {
|
||||
if cmp == "cpu_"+variant {
|
||||
dynLibs = append(dynLibs, availableDynLibs[cmp])
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
dynLibs = append(dynLibs, availableDynLibs["cpu"])
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, if we didn't find any matches, LCD CPU FTW
|
||||
if len(dynLibs) == 0 {
|
||||
dynLibs = []string{availableDynLibs["cpu"]}
|
||||
}
|
||||
return dynLibs
|
||||
}
|
||||
|
||||
func rocmDynLibPresent() bool {
|
||||
for dynLibName := range availableDynLibs {
|
||||
if strings.HasPrefix(dynLibName, "rocm") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func nativeInit(workdir string) error {
|
||||
slog.Info("Extracting dynamic libraries...")
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := extractPayloadFiles(workdir, "llama.cpp/ggml-metal.metal")
|
||||
if err != nil {
|
||||
if err == payloadMissing {
|
||||
// TODO perhaps consider this a hard failure on arm macs?
|
||||
slog.Info("ggml-meta.metal payload missing")
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
os.Setenv("GGML_METAL_PATH_RESOURCES", workdir)
|
||||
}
|
||||
|
||||
libs, err := extractDynamicLibs(workdir, "llama.cpp/build/*/*/*/lib/*")
|
||||
if err != nil {
|
||||
if err == payloadMissing {
|
||||
slog.Info(fmt.Sprintf("%s", payloadMissing))
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
for _, lib := range libs {
|
||||
// The last dir component is the variant name
|
||||
variant := filepath.Base(filepath.Dir(lib))
|
||||
availableDynLibs[variant] = lib
|
||||
}
|
||||
|
||||
if err := verifyDriverAccess(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Report which dynamic libraries we have loaded to assist troubleshooting
|
||||
variants := make([]string, len(availableDynLibs))
|
||||
i := 0
|
||||
for variant := range availableDynLibs {
|
||||
variants[i] = variant
|
||||
i++
|
||||
}
|
||||
slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
|
||||
slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractDynamicLibs(workDir, glob string) ([]string, error) {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return nil, payloadMissing
|
||||
}
|
||||
libs := []string{}
|
||||
|
||||
// TODO consider making this idempotent with some sort of persistent directory (where we store models probably)
|
||||
// and tracking by version so we don't reexpand the files every time
|
||||
// Also maybe consider lazy loading only what is needed
|
||||
|
||||
g := new(errgroup.Group)
|
||||
for _, file := range files {
|
||||
pathComps := strings.Split(file, "/")
|
||||
if len(pathComps) != pathComponentCount {
|
||||
slog.Error(fmt.Sprintf("unexpected payload components: %v", pathComps))
|
||||
continue
|
||||
}
|
||||
|
||||
file := file
|
||||
g.Go(func() error {
|
||||
// llama.cpp/build/$OS/$GOARCH/$VARIANT/lib/$LIBRARY
|
||||
// Include the variant in the path to avoid conflicts between multiple server libs
|
||||
targetDir := filepath.Join(workDir, pathComps[pathComponentCount-3])
|
||||
srcFile, err := libEmbed.Open(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read payload %s: %v", file, err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create payload temp dir %s: %v", workDir, err)
|
||||
}
|
||||
src := io.Reader(srcFile)
|
||||
filename := file
|
||||
if strings.HasSuffix(file, ".gz") {
|
||||
src, err = gzip.NewReader(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decompress payload %s: %v", file, err)
|
||||
}
|
||||
filename = strings.TrimSuffix(filename, ".gz")
|
||||
}
|
||||
|
||||
destFile := filepath.Join(targetDir, filepath.Base(filename))
|
||||
if strings.Contains(destFile, "server") {
|
||||
libs = append(libs, destFile)
|
||||
}
|
||||
|
||||
_, err = os.Stat(destFile)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write payload %s: %v", file, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, src); err != nil {
|
||||
return fmt.Errorf("copy payload %s: %v", file, err)
|
||||
}
|
||||
case err != nil:
|
||||
return fmt.Errorf("stat payload %s: %v", file, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return libs, g.Wait()
|
||||
}
|
||||
|
||||
func extractPayloadFiles(workDir, glob string) error {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return payloadMissing
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
srcFile, err := libEmbed.Open(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read payload %s: %v", file, err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
if err := os.MkdirAll(workDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create payload temp dir %s: %v", workDir, err)
|
||||
}
|
||||
src := io.Reader(srcFile)
|
||||
filename := file
|
||||
if strings.HasSuffix(file, ".gz") {
|
||||
src, err = gzip.NewReader(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decompress payload %s: %v", file, err)
|
||||
}
|
||||
filename = strings.TrimSuffix(filename, ".gz")
|
||||
}
|
||||
|
||||
destFile := filepath.Join(workDir, filepath.Base(filename))
|
||||
_, err = os.Stat(destFile)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write payload %s: %v", file, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, src); err != nil {
|
||||
return fmt.Errorf("copy payload %s: %v", file, err)
|
||||
}
|
||||
case err != nil:
|
||||
return fmt.Errorf("stat payload %s: %v", file, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyDriverAccess() error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
// Only check ROCm access if we have the dynamic lib loaded
|
||||
if rocmDynLibPresent() {
|
||||
// Verify we have permissions - either running as root, or we have group access to the driver
|
||||
fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrPermission) {
|
||||
return fmt.Errorf("Radeon card detected, but permissions not set up properly. Either run ollama as root, or add you user account to the render group.")
|
||||
} else if errors.Is(err, fs.ErrNotExist) {
|
||||
// expected behavior without a radeon card
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
|
||||
}
|
||||
fd.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
llm/payload_darwin_amd64.go (new file, 8 lines)
@@ -0,0 +1,8 @@
package llm

import (
	"embed"
)

//go:embed llama.cpp/ggml-metal.metal llama.cpp/build/darwin/x86_64/*/lib/*.dylib*
var libEmbed embed.FS
llm/payload_darwin_arm64.go (new file, 8 lines)
@@ -0,0 +1,8 @@
package llm

import (
	"embed"
)

//go:embed llama.cpp/ggml-metal.metal llama.cpp/build/darwin/arm64/*/lib/*.dylib*
var libEmbed embed.FS
llm/payload_linux.go (new file, 8 lines)
@@ -0,0 +1,8 @@
package llm

import (
	"embed"
)

//go:embed llama.cpp/build/linux/*/*/lib/*.so*
var libEmbed embed.FS
llm/payload_test.go (new file, 58 lines)
@@ -0,0 +1,58 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jmorganca/ollama/gpu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetDynLibs(t *testing.T) {
|
||||
availableDynLibs = map[string]string{
|
||||
"cpu": "X_cpu",
|
||||
}
|
||||
assert.Equal(t, false, rocmDynLibPresent())
|
||||
res := getDynLibs(gpu.GpuInfo{Library: "cpu"})
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, availableDynLibs["cpu"], res[0])
|
||||
|
||||
variant := gpu.GetCPUVariant()
|
||||
if variant != "" {
|
||||
variant = "_" + variant
|
||||
}
|
||||
availableDynLibs = map[string]string{
|
||||
"rocm_v5": "X_rocm_v5",
|
||||
"rocm_v6": "X_rocm_v6",
|
||||
"cpu" + variant: "X_cpu",
|
||||
}
|
||||
assert.Equal(t, true, rocmDynLibPresent())
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "rocm"})
|
||||
assert.Len(t, res, 3)
|
||||
assert.Equal(t, availableDynLibs["rocm_v5"], res[0])
|
||||
assert.Equal(t, availableDynLibs["rocm_v6"], res[1])
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[2])
|
||||
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
|
||||
assert.Len(t, res, 3)
|
||||
assert.Equal(t, availableDynLibs["rocm_v6"], res[0])
|
||||
assert.Equal(t, availableDynLibs["rocm_v5"], res[1])
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[2])
|
||||
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "cuda"})
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[0])
|
||||
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "default"})
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, "default", res[0])
|
||||
|
||||
availableDynLibs = map[string]string{
|
||||
"rocm": "X_rocm_v5",
|
||||
"cpu" + variant: "X_cpu",
|
||||
}
|
||||
assert.Equal(t, true, rocmDynLibPresent())
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
|
||||
assert.Len(t, res, 2)
|
||||
assert.Equal(t, availableDynLibs["rocm"], res[0])
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[1])
|
||||
}
|
llm/payload_windows.go (new file, 8 lines)
@@ -0,0 +1,8 @@
package llm

import (
	"embed"
)

//go:embed llama.cpp/build/windows/*/*/lib/*.dll*
var libEmbed embed.FS
@@ -1,71 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/gguf/ggml-metal.metal
|
||||
var libEmbed embed.FS
|
||||
|
||||
func newDynamicShimExtServer(library, model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
|
||||
// should never happen...
|
||||
return nil, fmt.Errorf("Dynamic library loading not supported on Mac")
|
||||
}
|
||||
|
||||
func nativeInit(workdir string) error {
|
||||
err := extractPayloadFiles(workdir, "llama.cpp/gguf/ggml-metal.metal")
|
||||
if err != nil {
|
||||
if err == payloadMissing {
|
||||
// TODO perhaps consider this a hard failure on arm macs?
|
||||
log.Printf("ggml-meta.metal payload missing")
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
os.Setenv("GGML_METAL_PATH_RESOURCES", workdir)
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractPayloadFiles(workDir, glob string) error {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return payloadMissing
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
srcFile, err := libEmbed.Open(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read payload %s: %v", file, err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
if err := os.MkdirAll(workDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create payload temp dir %s: %v", workDir, err)
|
||||
}
|
||||
|
||||
destFile := filepath.Join(workDir, filepath.Base(file))
|
||||
_, err = os.Stat(destFile)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write payload %s: %v", file, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, srcFile); err != nil {
|
||||
return fmt.Errorf("copy payload %s: %v", file, err)
|
||||
}
|
||||
case err != nil:
|
||||
return fmt.Errorf("stat payload %s: %v", file, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,191 +0,0 @@
|
||||
//go:build !darwin
|
||||
|
||||
package llm
|
||||
|
||||
/*
|
||||
|
||||
#include <stdlib.h>
|
||||
#include "dynamic_shim.h"
|
||||
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
type shimExtServer struct {
|
||||
s C.struct_dynamic_llama_server
|
||||
options api.Options
|
||||
}
|
||||
|
||||
// Note: current implementation does not support concurrent instantiations
|
||||
var shimMutex sync.Mutex
|
||||
var llm *shimExtServer
|
||||
|
||||
func (llm *shimExtServer) llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t) {
|
||||
C.dynamic_shim_llama_server_init(llm.s, sparams, err)
|
||||
}
|
||||
func (llm *shimExtServer) llama_server_start() {
|
||||
C.dynamic_shim_llama_server_start(llm.s)
|
||||
}
|
||||
func (llm *shimExtServer) llama_server_stop() {
|
||||
C.dynamic_shim_llama_server_stop(llm.s)
|
||||
}
|
||||
|
||||
func (llm *shimExtServer) llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t) {
|
||||
C.dynamic_shim_llama_server_completion(llm.s, json_req, resp)
|
||||
}
|
||||
func (llm *shimExtServer) llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t) {
|
||||
C.dynamic_shim_llama_server_completion_next_result(llm.s, task_id, resp)
|
||||
}
|
||||
func (llm *shimExtServer) llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t) {
|
||||
C.dynamic_shim_llama_server_completion_cancel(llm.s, task_id, err)
|
||||
}
|
||||
func (llm *shimExtServer) llama_server_release_task_result(result *C.ext_server_task_result_t) {
|
||||
C.dynamic_shim_llama_server_release_task_result(llm.s, result)
|
||||
}
|
||||
|
||||
func (llm *shimExtServer) llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
|
||||
C.dynamic_shim_llama_server_tokenize(llm.s, json_req, json_resp, err)
|
||||
}
|
||||
func (llm *shimExtServer) llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
|
||||
C.dynamic_shim_llama_server_detokenize(llm.s, json_req, json_resp, err)
|
||||
}
|
||||
func (llm *shimExtServer) llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
|
||||
C.dynamic_shim_llama_server_embedding(llm.s, json_req, json_resp, err)
|
||||
}
|
||||
func (llm *shimExtServer) llama_server_release_json_resp(json_resp **C.char) {
|
||||
C.dynamic_shim_llama_server_release_json_resp(llm.s, json_resp)
|
||||
}
|
||||
|
||||
func newDynamicShimExtServer(library, model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
|
||||
shimMutex.Lock()
|
||||
defer shimMutex.Unlock()
|
||||
updatePath(filepath.Dir(library))
|
||||
libPath := C.CString(library)
|
||||
defer C.free(unsafe.Pointer(libPath))
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
var srv C.struct_dynamic_llama_server
|
||||
C.dynamic_shim_init(libPath, &srv, &resp)
|
||||
if resp.id < 0 {
|
||||
return nil, fmt.Errorf("Unable to load dynamic library: %s", C.GoString(resp.msg))
|
||||
}
|
||||
llm = &shimExtServer{
|
||||
s: srv,
|
||||
options: opts,
|
||||
}
|
||||
log.Printf("Loading Dynamic Shim llm server: %s", library)
|
||||
return newExtServer(llm, model, adapters, projectors, numLayers, opts)
|
||||
}
|
||||
|
||||
func (llm *shimExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
|
||||
return predict(ctx, llm, pred, fn)
|
||||
}
|
||||
|
||||
func (llm *shimExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||
return encode(llm, ctx, prompt)
|
||||
}
|
||||
|
||||
func (llm *shimExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
|
||||
return decode(llm, ctx, tokens)
|
||||
}
|
||||
|
||||
func (llm *shimExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
|
||||
return embedding(llm, ctx, input)
|
||||
}
|
||||
|
||||
func (llm *shimExtServer) Close() {
|
||||
close(llm)
|
||||
}
|
||||
|
||||
func nativeInit(workdir string) error {
|
||||
libs, err := extractDynamicLibs(workdir, "llama.cpp/gguf/build/*/*/lib/*")
|
||||
if err != nil {
|
||||
if err == payloadMissing {
|
||||
log.Printf("%s", payloadMissing)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
for _, lib := range libs {
|
||||
// The last dir component is the variant name
|
||||
variant := filepath.Base(filepath.Dir(lib))
|
||||
AvailableShims[variant] = lib
|
||||
}
|
||||
|
||||
if err := verifyDriverAccess(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Report which dynamic libraries we have loaded to assist troubleshooting
|
||||
variants := make([]string, len(AvailableShims))
|
||||
i := 0
|
||||
for variant := range AvailableShims {
|
||||
variants[i] = variant
|
||||
i++
|
||||
}
|
||||
log.Printf("Dynamic LLM variants %v", variants)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractDynamicLibs(workDir, glob string) ([]string, error) {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return nil, payloadMissing
|
||||
}
|
||||
libs := []string{}
|
||||
|
||||
for _, file := range files {
|
||||
pathComps := strings.Split(file, "/")
|
||||
if len(pathComps) != 7 {
|
||||
log.Printf("unexpected payload components: %v", pathComps)
|
||||
continue
|
||||
}
|
||||
// llama.cpp/gguf/build/$OS/$VARIANT/lib/$LIBRARY
|
||||
// Include the variant in the path to avoid conflicts between multiple server libs
|
||||
targetDir := filepath.Join(workDir, pathComps[4])
|
||||
srcFile, err := libEmbed.Open(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read payload %s: %v", file, err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create payload temp dir %s: %v", workDir, err)
|
||||
}
|
||||
|
||||
destFile := filepath.Join(targetDir, filepath.Base(file))
|
||||
if strings.Contains(destFile, "server") {
|
||||
libs = append(libs, destFile)
|
||||
}
|
||||
|
||||
_, err = os.Stat(destFile)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write payload %s: %v", file, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, srcFile); err != nil {
|
||||
return nil, fmt.Errorf("copy payload %s: %v", file, err)
|
||||
}
|
||||
case err != nil:
|
||||
return nil, fmt.Errorf("stat payload %s: %v", file, err)
|
||||
}
|
||||
}
|
||||
return libs, nil
|
||||
}
|
@@ -1,46 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/gguf/build/*/*/lib/*.so
|
||||
var libEmbed embed.FS
|
||||
|
||||
func updatePath(dir string) {
|
||||
pathComponents := strings.Split(os.Getenv("PATH"), ":")
|
||||
for _, comp := range pathComponents {
|
||||
if comp == dir {
|
||||
return
|
||||
}
|
||||
}
|
||||
newPath := strings.Join(append(pathComponents, dir), ":")
|
||||
log.Printf("Updating PATH to %s", newPath)
|
||||
os.Setenv("PATH", newPath)
|
||||
}
|
||||
|
||||
func verifyDriverAccess() error {
|
||||
// Only check ROCm access if we have the dynamic lib loaded
|
||||
if _, rocmPresent := AvailableShims["rocm"]; rocmPresent {
|
||||
// Verify we have permissions - either running as root, or we have group access to the driver
|
||||
fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrPermission) {
|
||||
return fmt.Errorf("Radeon card detected, but permissions not set up properly. Either run ollama as root, or add you user account to the render group.")
|
||||
} else if errors.Is(err, fs.ErrNotExist) {
|
||||
// expected behavior without a radeon card
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
|
||||
}
|
||||
fd.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,36 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/gguf/build/windows/*/lib/*.dll
|
||||
var libEmbed embed.FS
|
||||
|
||||
func updatePath(dir string) {
|
||||
tmpDir := filepath.Dir(dir)
|
||||
pathComponents := strings.Split(os.Getenv("PATH"), ";")
|
||||
i := 0
|
||||
for _, comp := range pathComponents {
|
||||
if strings.EqualFold(comp, dir) {
|
||||
return
|
||||
}
|
||||
// Remove any other prior paths to our temp dir
|
||||
if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) {
|
||||
pathComponents[i] = comp
|
||||
i++
|
||||
}
|
||||
}
|
||||
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
|
||||
log.Printf("Updating PATH to %s", newPath)
|
||||
os.Setenv("PATH", newPath)
|
||||
}
|
||||
|
||||
func verifyDriverAccess() error {
|
||||
// TODO if applicable
|
||||
return nil
|
||||
}
|
@@ -6,7 +6,8 @@ import (
	"errors"
	"fmt"
	"io"
	"log"
	"log/slog"
	"slices"
)

type Command struct {
@@ -56,10 +57,20 @@ func Parse(reader io.Reader) ([]Command, error) {
	command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED":
	return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
case "MESSAGE":
	command.Name = string(bytes.ToLower(fields[0]))
	fields = bytes.SplitN(fields[1], []byte(" "), 2)
	if len(fields) < 2 {
		return nil, fmt.Errorf("should be in the format <role> <message>")
	}
	if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
		return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
	}
	command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
default:
	if !bytes.HasPrefix(fields[0], []byte("#")) {
		// log a warning for unknown commands
		log.Printf("WARNING: Unknown command: %s", fields[0])
		slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
	}
	continue
}
parser/parser_test.go (new file, 98 lines)
@@ -0,0 +1,98 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Parser(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM model1
|
||||
ADAPTER adapter1
|
||||
LICENSE MIT
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
TEMPLATE template1
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
commands, err := Parse(reader)
|
||||
assert.Nil(t, err)
|
||||
|
||||
expectedCommands := []Command{
|
||||
{Name: "model", Args: "model1"},
|
||||
{Name: "adapter", Args: "adapter1"},
|
||||
{Name: "license", Args: "MIT"},
|
||||
{Name: "param1", Args: "value1"},
|
||||
{Name: "param2", Args: "value2"},
|
||||
{Name: "template", Args: "template1"},
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedCommands, commands)
|
||||
}
|
||||
|
||||
func Test_Parser_NoFromLine(t *testing.T) {
|
||||
|
||||
input := `
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "no FROM line")
|
||||
}
|
||||
|
||||
func Test_Parser_MissingValue(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM foo
|
||||
PARAMETER param1
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "missing value for [param1]")
|
||||
|
||||
}
|
||||
|
||||
func Test_Parser_Messages(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM foo
|
||||
MESSAGE system You are a Parser. Always Parse things.
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
commands, err := Parse(reader)
|
||||
assert.Nil(t, err)
|
||||
|
||||
expectedCommands := []Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||
{Name: "message", Args: "user: Hey there!"},
|
||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedCommands, commands)
|
||||
}
|
||||
|
||||
func Test_Parser_Messages_BadRole(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM foo
|
||||
MESSAGE badguy I'm a bad guy!
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
}
|
@@ -77,7 +77,7 @@ func (p *Progress) Add(key string, state State) {
	p.states = append(p.states, state)
}

func (p *Progress) render() error {
func (p *Progress) render() {
	p.mu.Lock()
	defer p.mu.Unlock()

@@ -101,8 +101,6 @@ func (p *Progress) render() error {
	}

	p.pos = len(p.states)

	return nil
}

func (p *Progress) start() {
@@ -25,10 +25,7 @@ func NewBuffer(prompt *Prompt) (*Buffer, error) {
return nil, err
}

lwidth := width - len(prompt.Prompt)
if prompt.UseAlt {
lwidth = width - len(prompt.AltPrompt)
}
lwidth := width - len(prompt.prompt())

b := &Buffer{
Pos: 0,
@@ -78,7 +75,7 @@ func (b *Buffer) MoveRight() {
if b.Pos < b.Size() {
b.Pos += 1
if b.Pos%b.LineWidth == 0 {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(b.PromptSize()))
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
} else {
fmt.Print(CursorRight)
}
@@ -109,7 +106,7 @@ func (b *Buffer) MoveToStart() {
fmt.Print(CursorUp)
}
}
fmt.Printf(CursorBOL + cursorRightN(b.PromptSize()))
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())))
b.Pos = 0
}
}
@@ -123,7 +120,7 @@ func (b *Buffer) MoveToEnd() {
fmt.Print(CursorDown)
}
remainder := b.Size() % b.LineWidth
fmt.Printf(CursorBOL + cursorRightN(b.PromptSize()+remainder))
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())+remainder))
} else {
fmt.Print(cursorRightN(b.Size() - b.Pos))
}
@@ -136,20 +133,6 @@ func (b *Buffer) Size() int {
return b.Buf.Size()
}

func min(n, m int) int {
if n > m {
return m
}
return n
}

func (b *Buffer) PromptSize() int {
if b.Prompt.UseAlt {
return len(b.Prompt.AltPrompt)
}
return len(b.Prompt.Prompt)
}

func (b *Buffer) Add(r rune) {
if b.Pos == b.Buf.Size() {
fmt.Printf("%c", r)
@@ -232,7 +215,7 @@ func (b *Buffer) Remove() {
remainingLines := (b.Size() - b.Pos) / b.LineWidth
fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
place := b.Pos % b.LineWidth
fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.Prompt)))
fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.prompt())))
}
}
}
@@ -247,7 +230,7 @@ func (b *Buffer) Delete() {
remainingLines := (b.Size() - b.Pos) / b.LineWidth
fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
place := b.Pos % b.LineWidth
fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.Prompt)))
fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.prompt())))
}
}
}
@@ -294,15 +277,15 @@ func (b *Buffer) DeleteWord() {
}

func (b *Buffer) ClearScreen() {
fmt.Printf(ClearScreen + CursorReset + b.Prompt.Prompt)
fmt.Printf(ClearScreen + CursorReset + b.Prompt.prompt())
if b.IsEmpty() {
ph := b.Prompt.Placeholder
ph := b.Prompt.placeholder()
fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
} else {
currPos := b.Pos
b.Pos = 0
b.drawRemaining()
fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.Prompt)))
fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.prompt())))
if currPos > 0 {
targetLine := currPos / b.LineWidth
if targetLine > 0 {
@@ -329,7 +312,7 @@ func (b *Buffer) IsEmpty() bool {
func (b *Buffer) Replace(r []rune) {
b.Pos = 0
b.Buf.Clear()
fmt.Printf(ClearLine + CursorBOL + b.Prompt.Prompt)
fmt.Printf(ClearLine + CursorBOL + b.Prompt.prompt())
for _, c := range r {
b.Add(c)
}

@@ -23,7 +23,7 @@ type History struct {
func NewHistory() (*History, error) {
h := &History{
Buf: arraylist.New(),
Limit: 100, //resizeme
Limit: 100, // resizeme
Autosave: true,
Enabled: true,
}
@@ -49,7 +49,7 @@ func (h *History) Init() error {

h.Filename = path

f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0600)
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0o600)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
@@ -84,7 +84,7 @@ func (h *History) Add(l []rune) {
h.Compact()
h.Pos = h.Size()
if h.Autosave {
h.Save()
_ = h.Save()
}
}

@@ -132,7 +132,7 @@ func (h *History) Save() error {

tmpFile := h.Filename + ".tmp"

f, err := os.OpenFile(tmpFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC|os.O_APPEND, 0666)
f, err := os.OpenFile(tmpFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC|os.O_APPEND, 0o600)
if err != nil {
return err
}

@@ -16,6 +16,20 @@ type Prompt struct {
UseAlt bool
}

func (p *Prompt) prompt() string {
if p.UseAlt {
return p.AltPrompt
}
return p.Prompt
}

func (p *Prompt) placeholder() string {
if p.UseAlt {
return p.AltPlaceholder
}
return p.Placeholder
}

type Terminal struct {
outchan chan rune
}
@@ -46,8 +60,9 @@ func New(prompt Prompt) (*Instance, error) {
}

func (i *Instance) Readline() (string, error) {
prompt := i.Prompt.Prompt
if i.Prompt.UseAlt || i.Pasting {
prompt := i.Prompt.prompt()
if i.Pasting {
// force alt prompt when pasting
prompt = i.Prompt.AltPrompt
}
fmt.Print(prompt)
@@ -57,6 +72,7 @@ func (i *Instance) Readline() (string, error) {
if err != nil {
return "", err
}
// nolint: errcheck
defer UnsetRawMode(fd, termios)

buf, _ := NewBuffer(i.Prompt)
@@ -71,10 +87,7 @@ func (i *Instance) Readline() (string, error) {
// don't show placeholder when pasting unless we're in multiline mode
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
if buf.IsEmpty() && showPlaceholder {
ph := i.Prompt.Placeholder
if i.Prompt.UseAlt {
ph = i.Prompt.AltPlaceholder
}
ph := i.Prompt.placeholder()
fmt.Printf(ColorGrey + ph + fmt.Sprintf(CursorLeftN, len(ph)) + ColorDefault)
}

@@ -11,7 +11,7 @@ func handleCharCtrlZ(fd int, termios *Termios) (string, error) {
return "", err
}

syscall.Kill(0, syscall.SIGSTOP)
_ = syscall.Kill(0, syscall.SIGSTOP)

// on resume...
return "", nil

@@ -1,31 +1,46 @@
#!/bin/sh

set -eu
set -e

export VERSION=${VERSION:-0.0.0}
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"

mkdir -p dist

for TARGETARCH in arm64 amd64; do
rm -rf llm/llama.cpp/build
GOOS=darwin GOARCH=$TARGETARCH go generate ./...
CGO_ENABLED=1 GOOS=darwin GOARCH=$TARGETARCH go build -o dist/ollama-darwin-$TARGETARCH
rm -rf llm/llama.cpp/*/build
CGO_ENABLED=1 GOOS=darwin GOARCH=$TARGETARCH go build -cover -o dist/ollama-darwin-$TARGETARCH-cov
done

lipo -create -output dist/ollama dist/ollama-darwin-*
rm -f dist/ollama-darwin-*
codesign --deep --force --options=runtime --sign "$APPLE_IDENTITY" --timestamp dist/ollama
lipo -create -output dist/ollama dist/ollama-darwin-arm64 dist/ollama-darwin-amd64
rm -f dist/ollama-darwin-arm64 dist/ollama-darwin-amd64
if [ -n "$APPLE_IDENTITY" ]; then
codesign --deep --force --options=runtime --sign "$APPLE_IDENTITY" --timestamp dist/ollama
else
echo "Skipping code signing - set APPLE_IDENTITY"
fi
chmod +x dist/ollama

# build and sign the mac app
# build and optionally sign the mac app
npm install --prefix app
npm run --prefix app make:sign
if [ -n "$APPLE_IDENTITY" ]; then
npm run --prefix app make:sign
else
npm run --prefix app make
fi
cp app/out/make/zip/darwin/universal/Ollama-darwin-universal-$VERSION.zip dist/Ollama-darwin.zip

# sign the binary and rename it
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/ollama
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/ollama
else
echo "WARNING: Skipping code signing - set APPLE_IDENTITY"
fi
ditto -c -k --keepParent dist/ollama dist/temp.zip
xcrun notarytool submit dist/temp.zip --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
if [ -n "$APPLE_IDENTITY" ]; then
xcrun notarytool submit dist/temp.zip --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
fi
mv dist/ollama dist/ollama-darwin
rm -f dist/temp.zip

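Note: the GOFLAGS exported in the build scripts above stamp the binary through -X linker flags, which only works when the target is a plain package-level string variable. Below is a minimal sketch of what the referenced version package presumably looks like; the default value is an assumption taken from the scripts' 0.0.0 fallback, not from this diff.

// Package version holds the version string stamped at build time.
// The release scripts override it with:
//   go build -ldflags "-X=github.com/jmorganca/ollama/version.Version=<version>"
package version

var Version = "0.0.0"
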
@@ -2,7 +2,7 @@

set -eu

export VERSION=${VERSION:-0.0.0}
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"

docker build \
@@ -13,3 +13,13 @@ docker build \
-f Dockerfile \
-t ollama/ollama:$VERSION \
.

docker build \
--load \
--platform=linux/amd64 \
--build-arg=VERSION \
--build-arg=GOFLAGS \
--target runtime-rocm \
-f Dockerfile \
-t ollama/ollama:$VERSION-rocm \
.

@@ -2,13 +2,24 @@

set -eu

export VERSION=${VERSION:-0.0.0}
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"

BUILD_ARCH=${BUILD_ARCH:-"amd64 arm64"}
export AMDGPU_TARGETS=${AMDGPU_TARGETS:=""}
mkdir -p dist

for TARGETARCH in amd64 arm64; do
docker build --platform=linux/$TARGETARCH --build-arg=VERSION --build-arg=GOFLAGS --build-arg=CGO_CFLAGS -f Dockerfile.build -t builder:$TARGETARCH .
for TARGETARCH in ${BUILD_ARCH}; do
docker build \
--platform=linux/$TARGETARCH \
--build-arg=GOFLAGS \
--build-arg=CGO_CFLAGS \
--build-arg=OLLAMA_CUSTOM_CPU_DEFS \
--build-arg=AMDGPU_TARGETS \
--target build-$TARGETARCH \
-f Dockerfile \
-t builder:$TARGETARCH \
.
docker create --platform linux/$TARGETARCH --name builder-$TARGETARCH builder:$TARGETARCH
docker cp builder-$TARGETARCH:/go/src/github.com/jmorganca/ollama/ollama ./dist/ollama-linux-$TARGETARCH
docker rm builder-$TARGETARCH

@@ -66,3 +66,7 @@ subprocess.check_call(['ssh', netloc, 'cd', path, ';', GoCmd, 'generate', './...
print("Building")
subprocess.check_call(['ssh', netloc, 'cd', path, ';', GoCmd, 'build', '.'])

print("Copying built result")
subprocess.check_call(['scp', netloc +":"+ path + "/ollama.exe", './dist/'])

@@ -231,8 +231,8 @@ if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\
case $OS_NAME in
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
fedora) install_cuda_driver_yum $OS_NAME $OS_VERSION ;;
amzn) install_cuda_driver_yum 'fedora' '35' ;;
fedora) [ $OS_VERSION -lt '37' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '37';;
amzn) install_cuda_driver_yum 'fedora' '37' ;;
debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;;
ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;;
*) exit ;;

scripts/rh_linux_deps.sh (new file, 43 lines)
@@ -0,0 +1,43 @@
#!/bin/sh

# Script for common Dockerfile dependency installation in redhat linux based images

set -ex
MACHINE=$(uname -m)

if grep -i "centos" /etc/system-release >/dev/null; then
# Centos 7 derivatives have too old of a git version to run our generate script
# uninstall and ignore failures
yum remove -y git
yum -y install epel-release centos-release-scl
yum -y install dnf
if [ "${MACHINE}" = "x86_64" ]; then
yum -y install https://repo.ius.io/ius-release-el7.rpm
dnf install -y git236
else
dnf install -y rh-git227-git
ln -s /opt/rh/rh-git227/root/usr/bin/git /usr/local/bin/git
fi
dnf install -y devtoolset-10-gcc devtoolset-10-gcc-c++
elif grep -i "rocky" /etc/system-release >/dev/null; then
dnf install -y git gcc-toolset-10-gcc gcc-toolset-10-gcc-c++
else
echo "ERROR Unexpected distro"
exit 1
fi

if [ -n "${CMAKE_VERSION}" ]; then
curl -s -L https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-$(uname -m).tar.gz | tar -zx -C /usr --strip-components 1
fi

if [ -n "${GOLANG_VERSION}" ]; then
if [ "${MACHINE}" = "x86_64" ]; then
GO_ARCH="amd64"
else
GO_ARCH="arm64"
fi
mkdir -p /usr/local
curl -s -L https://dl.google.com/go/go${GOLANG_VERSION}.linux-${GO_ARCH}.tar.gz | tar xz -C /usr/local
ln -s /usr/local/go/bin/go /usr/local/bin/go
ln -s /usr/local/go/bin/gofmt /usr/local/bin/gofmt
fi

@@ -10,7 +10,7 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"log/slog"
"net/http"
"net/url"
"os"
@@ -86,7 +86,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {

rawKey, err := os.ReadFile(keyPath)
if err != nil {
log.Printf("Failed to load private key: %v", err)
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}

@@ -105,7 +105,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
headers.Set("Authorization", sig)
resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
if err != nil {
log.Printf("couldn't get token: %q", err)
slog.Info(fmt.Sprintf("couldn't get token: %q", err))
return "", err
}
defer resp.Body.Close()
@@ -147,12 +147,7 @@ func (s SignatureData) Bytes() []byte {

// SignData takes a SignatureData object and signs it with a raw private key
func (s SignatureData) Sign(rawKey []byte) (string, error) {
privateKey, err := ssh.ParseRawPrivateKey(rawKey)
if err != nil {
return "", err
}

signer, err := ssh.NewSignerFromKey(privateKey)
signer, err := ssh.ParsePrivateKey(rawKey)
if err != nil {
return "", err
}

@@ -6,7 +6,7 @@ import (
"errors"
"fmt"
"io"
"log"
"log/slog"
"math"
"net/http"
"net/url"
@@ -25,6 +25,11 @@ import (
"github.com/jmorganca/ollama/format"
)

const maxRetries = 6

var errMaxRetriesExceeded = errors.New("max retries exceeded")
var errPartStalled = errors.New("part stalled")

var blobDownloadManager sync.Map

type blobDownload struct {
@@ -44,10 +49,11 @@ type blobDownload struct {
}

type blobDownloadPart struct {
N int
Offset int64
Size int64
Completed int64
N int
Offset int64
Size int64
Completed int64
lastUpdated time.Time

*blobDownload `json:"-"`
}
@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
|
||||
return p.Offset + p.Size
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
p.blobDownload.Completed.Add(int64(n))
|
||||
p.lastUpdated = time.Now()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
|
||||
if err != nil {
|
||||
@@ -98,7 +111,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
|
||||
|
||||
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||
|
||||
var size = b.Total / numDownloadParts
|
||||
size := b.Total / numDownloadParts
|
||||
switch {
|
||||
case size < minDownloadPartSize:
|
||||
size = minDownloadPartSize
|
||||
@@ -120,7 +133,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
|
||||
slog.Info(fmt.Sprintf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -132,13 +145,13 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
defer blobDownloadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
|
||||
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
file.Truncate(b.Total)
|
||||
_ = file.Truncate(b.Total)
|
||||
|
||||
g, inner := errgroup.WithContext(ctx)
|
||||
g.SetLimit(numDownloadParts)
|
||||
@@ -157,9 +170,12 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
||||
// return immediately if the context is canceled or the device is out of space
|
||||
return err
|
||||
case errors.Is(err, errPartStalled):
|
||||
try--
|
||||
continue
|
||||
case err != nil:
|
||||
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
|
||||
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
|
||||
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
|
||||
time.Sleep(sleep)
|
||||
continue
|
||||
default:
|
||||
@@ -195,28 +211,54 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
}
|
||||
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
n, err := io.Copy(w, io.TeeReader(resp.Body, b))
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
}
|
||||
n, err := io.Copy(w, io.TeeReader(resp.Body, part))
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
}
|
||||
|
||||
part.Completed += n
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
}
|
||||
part.Completed += n
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// return nil or context.Canceled or UnexpectedEOF (resumable)
|
||||
return err
|
||||
// return nil or context.Canceled or UnexpectedEOF (resumable)
|
||||
return err
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if part.Completed >= part.Size {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
|
||||
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
|
||||
// reset last updated
|
||||
part.lastUpdated = time.Time{}
|
||||
return errPartStalled
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (b *blobDownload) newPart(offset, size int64) error {
|
||||
@@ -246,7 +288,7 @@ func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
|
||||
}
|
||||
|
||||
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
|
||||
partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
|
||||
partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
|
||||
return json.NewEncoder(partFile).Encode(part)
|
||||
}
|
||||
|
||||
func (b *blobDownload) Write(p []byte) (n int, err error) {
|
||||
n = len(p)
|
||||
b.Completed.Add(int64(n))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) acquire() {
|
||||
b.references.Add(1)
|
||||
}
|
||||
@@ -279,20 +315,19 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
||||
Digest: b.Digest,
|
||||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
})
|
||||
|
||||
if b.done || b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
||||
Digest: b.Digest,
|
||||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
})
|
||||
|
||||
if b.done || b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,10 +338,6 @@ type downloadOpts struct {
|
||||
fn func(api.ProgressResponse)
|
||||
}
|
||||
|
||||
const maxRetries = 6
|
||||
|
||||
var errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) error {
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
@@ -340,6 +371,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// nolint: contextcheck
|
||||
go download.Run(context.Background(), requestURL, opts.regOpts)
|
||||
}
|
||||
|
||||
|
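The download.go changes above move byte accounting onto blobDownloadPart.Write, track lastUpdated per part, and add a watchdog goroutine that returns errPartStalled when a part stops making progress so it can be retried without consuming a retry attempt. Below is a minimal, self-contained sketch of that stall-detection pattern; progressWriter, copyWithStallCheck, and the timeouts are illustrative names and values, not the actual Ollama implementation.

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"
	"sync/atomic"
	"time"

	"golang.org/x/sync/errgroup"
)

var errStalled = errors.New("transfer stalled")

// progressWriter records how many bytes have arrived and when the last write happened.
type progressWriter struct {
	completed   atomic.Int64
	lastUpdated atomic.Int64 // unix nanoseconds of the most recent write
}

func (p *progressWriter) Write(b []byte) (int, error) {
	p.completed.Add(int64(len(b)))
	p.lastUpdated.Store(time.Now().UnixNano())
	return len(b), nil
}

// copyWithStallCheck copies total bytes from src to dst while a watchdog goroutine
// aborts the transfer if no new bytes arrive for stallTimeout.
func copyWithStallCheck(ctx context.Context, dst io.Writer, src io.Reader, total int64, stallTimeout time.Duration) error {
	pw := &progressWriter{}
	g, ctx := errgroup.WithContext(ctx)

	g.Go(func() error {
		// the TeeReader mirrors every chunk into the progress writer
		_, err := io.Copy(dst, io.TeeReader(src, pw))
		return err
	})

	g.Go(func() error {
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if pw.completed.Load() >= total {
					return nil // all bytes arrived; stop watching
				}
				if last := pw.lastUpdated.Load(); last != 0 && time.Since(time.Unix(0, last)) > stallTimeout {
					return errStalled // caller can retry this part instead of failing the download
				}
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	})

	return g.Wait()
}

func main() {
	src := strings.NewReader("hello, world")
	var dst strings.Builder
	err := copyWithStallCheck(context.Background(), &dst, src, int64(len("hello, world")), 5*time.Second)
	fmt.Println(dst.String(), err)
}
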
server/images.go (251 lines changed)
@@ -10,6 +10,7 @@ import (
"fmt"
"io"
"log"
"log/slog"
"net/http"
"net/url"
"os"
@@ -40,7 +41,7 @@ type Model struct {
Config ConfigV2
ShortName string
ModelPath string
OriginalModel string
ParentModel string
AdapterPaths []string
ProjectorPaths []string
Template string
@@ -49,6 +50,12 @@ type Model struct {
Digest string
Size int64
Options map[string]interface{}
Messages []Message
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type PromptVars struct {
@@ -56,6 +63,7 @@ type PromptVars struct {
|
||||
Prompt string
|
||||
Response string
|
||||
First bool
|
||||
Images []llm.ImageData
|
||||
}
|
||||
|
||||
// extractParts extracts the parts of the template before and after the {{.Response}} node.
|
||||
@@ -112,10 +120,6 @@ func Prompt(promptTemplate string, p PromptVars) (string, error) {
|
||||
|
||||
// PreResponsePrompt returns the prompt before the response tag
|
||||
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
|
||||
if p.System == "" {
|
||||
// use the default system prompt for this model if one is not specified
|
||||
p.System = m.System
|
||||
}
|
||||
pre, _, err := extractParts(m.Template)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -143,62 +147,68 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
|
||||
return Prompt(post, p)
|
||||
}
|
||||
|
||||
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
|
||||
type ChatHistory struct {
|
||||
Prompts []PromptVars
|
||||
LastSystem string
|
||||
}
|
||||
|
||||
// ChatPrompts returns a list of formatted chat prompts from a list of messages
|
||||
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||
// build the prompt from the list of messages
|
||||
var prompt strings.Builder
|
||||
var currentImages []api.ImageData
|
||||
lastSystem := m.System
|
||||
currentVars := PromptVars{
|
||||
First: true,
|
||||
System: m.System,
|
||||
}
|
||||
|
||||
writePrompt := func() error {
|
||||
p, err := Prompt(m.Template, currentVars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prompt.WriteString(p)
|
||||
currentVars = PromptVars{}
|
||||
return nil
|
||||
}
|
||||
prompts := []PromptVars{}
|
||||
var images []llm.ImageData
|
||||
|
||||
for _, msg := range msgs {
|
||||
switch strings.ToLower(msg.Role) {
|
||||
case "system":
|
||||
if currentVars.System != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
// if this is the first message it overrides the system prompt in the modelfile
|
||||
if !currentVars.First && currentVars.System != "" {
|
||||
prompts = append(prompts, currentVars)
|
||||
currentVars = PromptVars{}
|
||||
}
|
||||
currentVars.System = msg.Content
|
||||
lastSystem = msg.Content
|
||||
case "user":
|
||||
if currentVars.Prompt != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
prompts = append(prompts, currentVars)
|
||||
currentVars = PromptVars{}
|
||||
}
|
||||
|
||||
currentVars.Prompt = msg.Content
|
||||
currentImages = msg.Images
|
||||
for i := range msg.Images {
|
||||
id := len(images) + i
|
||||
currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
|
||||
currentVars.Images = append(currentVars.Images, llm.ImageData{
|
||||
ID: id,
|
||||
Data: msg.Images[i],
|
||||
})
|
||||
}
|
||||
|
||||
images = append(images, currentVars.Images...)
|
||||
case "assistant":
|
||||
currentVars.Response = msg.Content
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
prompts = append(prompts, currentVars)
|
||||
currentVars = PromptVars{}
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// Append the last set of vars if they are non-empty
|
||||
if currentVars.Prompt != "" || currentVars.System != "" {
|
||||
p, err := m.PreResponsePrompt(currentVars)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("pre-response template: %w", err)
|
||||
}
|
||||
prompt.WriteString(p)
|
||||
prompts = append(prompts, currentVars)
|
||||
}
|
||||
|
||||
return prompt.String(), currentImages, nil
|
||||
return &ChatHistory{
|
||||
Prompts: prompts,
|
||||
LastSystem: lastSystem,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ManifestV2 struct {
|
||||
@@ -332,11 +342,11 @@ func GetModel(name string) (*Model, error) {
|
||||
switch layer.MediaType {
|
||||
case "application/vnd.ollama.image.model":
|
||||
model.ModelPath = filename
|
||||
model.OriginalModel = layer.From
|
||||
model.ParentModel = layer.From
|
||||
case "application/vnd.ollama.image.embed":
|
||||
// Deprecated in versions > 0.1.2
|
||||
// TODO: remove this warning in a future version
|
||||
log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
||||
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
||||
case "application/vnd.ollama.image.adapter":
|
||||
model.AdapterPaths = append(model.AdapterPaths, filename)
|
||||
case "application/vnd.ollama.image.projector":
|
||||
@@ -373,6 +383,16 @@ func GetModel(name string) (*Model, error) {
|
||||
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.messages":
|
||||
msgs, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer msgs.Close()
|
||||
|
||||
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.license":
|
||||
bts, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
@@ -411,6 +431,13 @@ func realpath(mfDir, from string) string {
|
||||
}
|
||||
|
||||
func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
|
||||
deleteMap := make(map[string]struct{})
|
||||
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
|
||||
for _, layer := range append(manifest.Layers, manifest.Config) {
|
||||
deleteMap[layer.Digest] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
config := ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
@@ -419,15 +446,13 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
},
|
||||
}
|
||||
|
||||
deleteMap := make(map[string]struct{})
|
||||
|
||||
var layers Layers
|
||||
messages := []string{}
|
||||
|
||||
params := make(map[string][]string)
|
||||
fromParams := make(map[string]any)
|
||||
|
||||
for _, c := range commands {
|
||||
log.Printf("[%s] - %s", c.Name, c.Args)
|
||||
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
||||
|
||||
switch c.Name {
|
||||
@@ -478,32 +503,6 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
return err
|
||||
}
|
||||
|
||||
// if the model is not in gguf format, pull the base model to try and get it in gguf format
|
||||
if fromConfig.ModelFormat != "gguf" {
|
||||
fn(api.ProgressResponse{Status: "updating base model"})
|
||||
parent, err := GetModel(c.Args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
originalModel := parent.OriginalModel
|
||||
if originalModel == "" {
|
||||
originalModel = parent.ShortName
|
||||
}
|
||||
if err := PullModel(ctx, originalModel, &RegistryOptions{}, fn); err != nil {
|
||||
log.Printf("error pulling parent model: %v", err)
|
||||
}
|
||||
|
||||
// Reset the file pointer to the beginning of the file
|
||||
_, err = fromConfigFile.Seek(0, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update from config after pull: %w", err)
|
||||
}
|
||||
if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// if the model is still not in gguf format, error out
|
||||
if fromConfig.ModelFormat != "gguf" {
|
||||
return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args)
|
||||
@@ -627,11 +626,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
}
|
||||
|
||||
layers.Replace(layer)
|
||||
case "message":
|
||||
messages = append(messages, c.Args)
|
||||
default:
|
||||
params[c.Name] = append(params[c.Name], c.Args)
|
||||
}
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
fn(api.ProgressResponse{Status: "creating parameters layer"})
|
||||
|
||||
msgs := make([]api.Message, 0)
|
||||
|
||||
for _, m := range messages {
|
||||
// todo: handle images
|
||||
msg := strings.SplitN(m, ": ", 2)
|
||||
msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(msgs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layers.Replace(layer)
|
||||
}
|
||||
|
||||
if len(params) > 0 {
|
||||
fn(api.ProgressResponse{Status: "creating parameters layer"})
|
||||
|
||||
@@ -773,6 +798,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{},
|
||||
// save (i.e. delete from the deleteMap) any files used in other manifests
|
||||
manifest, _, err := GetManifest(fmp)
|
||||
if err != nil {
|
||||
// nolint: nilerr
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -792,16 +818,16 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{},
|
||||
for k := range deleteMap {
|
||||
fp, err := GetBlobsPath(k)
|
||||
if err != nil {
|
||||
log.Printf("couldn't get file path for '%s': %v", k, err)
|
||||
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
|
||||
continue
|
||||
}
|
||||
if !dryRun {
|
||||
if err := os.Remove(fp); err != nil {
|
||||
log.Printf("couldn't remove file '%s': %v", fp, err)
|
||||
slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err))
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
log.Printf("wanted to remove: %s", fp)
|
||||
slog.Info(fmt.Sprintf("wanted to remove: %s", fp))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -817,7 +843,7 @@ func PruneLayers() error {
|
||||
|
||||
blobs, err := os.ReadDir(p)
|
||||
if err != nil {
|
||||
log.Printf("couldn't read dir '%s': %v", p, err)
|
||||
slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", p, err))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -831,14 +857,14 @@ func PruneLayers() error {
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("total blobs: %d", len(deleteMap))
|
||||
slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap)))
|
||||
|
||||
err = deleteUnusedLayers(nil, deleteMap, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("total unused blobs removed: %d", len(deleteMap))
|
||||
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -900,7 +926,7 @@ func DeleteModel(name string) error {
|
||||
}
|
||||
err = os.Remove(fp)
|
||||
if err != nil {
|
||||
log.Printf("couldn't remove manifest file '%s': %v", fp, err)
|
||||
slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -927,8 +953,8 @@ func ShowModelfile(model *Model) (string, error) {
|
||||
mt.Model = model
|
||||
mt.From = model.ModelPath
|
||||
|
||||
if model.OriginalModel != "" {
|
||||
mt.From = model.OriginalModel
|
||||
if model.ParentModel != "" {
|
||||
mt.From = model.ParentModel
|
||||
}
|
||||
|
||||
modelFile := `# Modelfile generated by "ollama show"
|
||||
@@ -954,14 +980,14 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
|
||||
|
||||
tmpl, err := template.New("").Parse(modelFile)
|
||||
if err != nil {
|
||||
log.Printf("error parsing template: %q", err)
|
||||
slog.Info(fmt.Sprintf("error parsing template: %q", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err = tmpl.Execute(&buf, mt); err != nil {
|
||||
log.Printf("error executing template: %q", err)
|
||||
slog.Info(fmt.Sprintf("error executing template: %q", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -988,7 +1014,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
|
||||
for _, layer := range layers {
|
||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||
log.Printf("error uploading blob: %v", err)
|
||||
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
||||
if errors.Is(err, errUnauthorized) {
|
||||
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
|
||||
}
|
||||
@@ -1083,7 +1109,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
}
|
||||
if err := os.Remove(fp); err != nil {
|
||||
// log this, but return the original error
|
||||
log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
|
||||
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
|
||||
}
|
||||
}
|
||||
return err
|
||||
@@ -1107,7 +1133,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
|
||||
err = os.WriteFile(fp, manifestJSON, 0o644)
|
||||
if err != nil {
|
||||
log.Printf("couldn't write to %s", fp)
|
||||
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1157,49 +1183,46 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
var errUnauthorized = fmt.Errorf("unauthorized")
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
log.Printf("request failed: %v", err)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
// Handle authentication error with one retry
|
||||
auth := resp.Header.Get("www-authenticate")
|
||||
authRedir := ParseAuthRedirectString(auth)
|
||||
token, err := getAuthToken(ctx, authRedir)
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
slog.Info(fmt.Sprintf("request failed: %v", err))
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
regOpts.Token = token
|
||||
if body != nil {
|
||||
_, err = body.Seek(0, io.SeekStart)
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
// Handle authentication error with one retry
|
||||
auth := resp.Header.Get("www-authenticate")
|
||||
authRedir := ParseAuthRedirectString(auth)
|
||||
token, err := getAuthToken(ctx, authRedir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
regOpts.Token = token
|
||||
if body != nil {
|
||||
_, err = body.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
case resp.StatusCode == http.StatusNotFound:
|
||||
return nil, os.ErrNotExist
|
||||
case resp.StatusCode >= http.StatusBadRequest:
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
||||
}
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
||||
default:
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
|
||||
return resp, err
|
||||
case resp.StatusCode == http.StatusNotFound:
|
||||
return nil, os.ErrNotExist
|
||||
case resp.StatusCode >= http.StatusBadRequest:
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
||||
}
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
|
||||
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
|
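The makeRequestWithRetry rewrite above replaces the duplicated request code with a bounded loop: issue the request, and on a 401 fetch a fresh token from the www-authenticate challenge and try exactly once more. Below is a rough sketch of that shape; doWithReauth and fetchToken are hypothetical stand-ins, not the real registry client.

package authretry

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
)

var errUnauthorized = errors.New("unauthorized")

// doWithReauth issues a GET and, on a 401, refreshes the token once and retries.
func doWithReauth(ctx context.Context, client *http.Client, url, token string,
	fetchToken func(ctx context.Context, challenge string) (string, error),
) (*http.Response, error) {
	for i := 0; i < 2; i++ {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return nil, err
		}
		if token != "" {
			req.Header.Set("Authorization", "Bearer "+token)
		}

		resp, err := client.Do(req)
		if err != nil {
			return nil, err
		}

		switch {
		case resp.StatusCode == http.StatusUnauthorized:
			// first attempt only: parse the challenge, get a token, and loop
			challenge := resp.Header.Get("www-authenticate")
			resp.Body.Close()
			token, err = fetchToken(ctx, challenge)
			if err != nil {
				return nil, err
			}
		case resp.StatusCode >= http.StatusBadRequest:
			body, _ := io.ReadAll(resp.Body)
			resp.Body.Close()
			return nil, fmt.Errorf("%d: %s", resp.StatusCode, body)
		default:
			return resp, nil
		}
	}
	// still unauthorized after refreshing the token once
	return nil, errUnauthorized
}
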
@@ -1,6 +1,7 @@
package server

import (
"bytes"
"strings"
"testing"

@@ -233,17 +234,58 @@ func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func chatHistoryEqual(a, b ChatHistory) bool {
|
||||
if len(a.Prompts) != len(b.Prompts) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a.Prompts {
|
||||
|
||||
if v.First != b.Prompts[i].First {
|
||||
return false
|
||||
}
|
||||
|
||||
if v.Response != b.Prompts[i].Response {
|
||||
return false
|
||||
}
|
||||
|
||||
if v.Prompt != b.Prompts[i].Prompt {
|
||||
return false
|
||||
}
|
||||
|
||||
if v.System != b.Prompts[i].System {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(v.Images) != len(b.Prompts[i].Images) {
|
||||
return false
|
||||
}
|
||||
|
||||
for j, img := range v.Images {
|
||||
if img.ID != b.Prompts[i].Images[j].ID {
|
||||
return false
|
||||
}
|
||||
|
||||
if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return a.LastSystem == b.LastSystem
|
||||
}
|
||||
|
||||
func TestChat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
msgs []api.Message
|
||||
want string
|
||||
wantErr string
|
||||
name string
|
||||
model Model
|
||||
msgs []api.Message
|
||||
want ChatHistory
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "Single Message",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
name: "Single Message",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
@@ -254,34 +296,22 @@ func TestChat(t *testing.T) {
|
||||
Content: "What are the potion ingredients?",
|
||||
},
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "First Message",
|
||||
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a Wizard.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What are the potion ingredients?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "eye of newt",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Anything else?",
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Message History",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
name: "Message History",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
@@ -300,18 +330,85 @@ func TestChat(t *testing.T) {
|
||||
Content: "Anything else?",
|
||||
},
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
Response: "sugar",
|
||||
First: true,
|
||||
},
|
||||
{
|
||||
Prompt: "Anything else?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Assistant Only",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
name: "Assistant Only",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "everything nice",
|
||||
},
|
||||
},
|
||||
want: "[INST] [/INST]everything nice",
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Response: "everything nice",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Last system message is preserved from modelfile",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
System: "You are Mojo Jojo.",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hi",
|
||||
},
|
||||
},
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are Mojo Jojo.",
|
||||
Prompt: "hi",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are Mojo Jojo.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Last system message is preserved from messages",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
System: "You are Mojo Jojo.",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are Professor Utonium.",
|
||||
},
|
||||
},
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are Professor Utonium.",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are Professor Utonium.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid Role",
|
||||
@@ -326,11 +423,8 @@ func TestChat(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
m := Model{
|
||||
Template: tt.template,
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, _, err := m.ChatPrompt(tt.msgs)
|
||||
got, err := tt.model.ChatPrompts(tt.msgs)
|
||||
if tt.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Errorf("ChatPrompt() expected error, got nil")
|
||||
@@ -338,9 +432,10 @@ func TestChat(t *testing.T) {
|
||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
|
||||
if !chatHistoryEqual(*got, tt.want) {
|
||||
t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@@ -26,9 +26,9 @@ func WriteManifest(name string, config *Layer, layers []*Layer) error {
return err
}

if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
return err
}

return os.WriteFile(manifestPath, b.Bytes(), 0644)
return os.WriteFile(manifestPath, b.Bytes(), 0o644)
}

@@ -46,7 +46,8 @@ func ParseModelPath(name string) ModelPath {
name = after
}

parts := strings.Split(name, string(os.PathSeparator))
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]

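The ParseModelPath change above normalizes the OS-specific path separator to "/" before splitting, so a model reference written with backslashes on Windows resolves the same way everywhere. A small illustrative sketch of that step (not the full ModelPath type):

package main

import (
	"fmt"
	"os"
	"strings"
)

// splitModelName normalizes platform path separators before splitting, so
// `registry\namespace\model` on Windows behaves like "registry/namespace/model".
func splitModelName(name string) []string {
	name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
	return strings.Split(name, "/")
}

func main() {
	fmt.Println(splitModelName("registry.ollama.ai/library/llama2"))
}
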
server/routes.go (361 lines changed)
@@ -7,7 +7,7 @@ import (
"fmt"
"io"
"io/fs"
"log"
"log/slog"
"net"
"net/http"
"os"
@@ -15,7 +15,6 @@ import (
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
@@ -74,7 +73,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
|
||||
|
||||
if needLoad {
|
||||
if loaded.runner != nil {
|
||||
log.Println("changing loaded model")
|
||||
slog.Info("changing loaded model")
|
||||
loaded.runner.Close()
|
||||
loaded.runner = nil
|
||||
loaded.Model = nil
|
||||
@@ -187,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
sessionDuration := defaultSessionDuration
|
||||
var sessionDuration time.Duration
|
||||
if req.KeepAlive == nil {
|
||||
sessionDuration = defaultSessionDuration
|
||||
} else {
|
||||
sessionDuration = req.KeepAlive.Duration
|
||||
}
|
||||
|
||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -198,7 +203,8 @@ func GenerateHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Model: req.Model,
|
||||
Done: true})
|
||||
Done: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -233,6 +239,15 @@ func GenerateHandler(c *gin.Context) {
|
||||
Prompt: req.Prompt,
|
||||
First: len(req.Context) == 0,
|
||||
}
|
||||
|
||||
if promptVars.System == "" {
|
||||
promptVars.System = model.System
|
||||
}
|
||||
|
||||
for i := range req.Images {
|
||||
promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
|
||||
}
|
||||
|
||||
p, err := model.PreResponsePrompt(promptVars)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -242,6 +257,8 @@ func GenerateHandler(c *gin.Context) {
|
||||
prompt = rebuild.String()
|
||||
}
|
||||
|
||||
slog.Debug("generate handler", "prompt", prompt)
|
||||
|
||||
ch := make(chan any)
|
||||
var generated strings.Builder
|
||||
go func() {
|
||||
@@ -295,11 +312,19 @@ func GenerateHandler(c *gin.Context) {
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
var images []llm.ImageData
|
||||
for i := range req.Images {
|
||||
images = append(images, llm.ImageData{
|
||||
ID: i,
|
||||
Data: req.Images[i],
|
||||
})
|
||||
}
|
||||
|
||||
// Start prediction
|
||||
predictReq := llm.PredictOpts{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: req.Images,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
@@ -378,7 +403,14 @@ func EmbeddingHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
sessionDuration := defaultSessionDuration
|
||||
|
||||
var sessionDuration time.Duration
|
||||
if req.KeepAlive == nil {
|
||||
sessionDuration = defaultSessionDuration
|
||||
} else {
|
||||
sessionDuration = req.KeepAlive.Duration
|
||||
}
|
||||
|
||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -391,7 +423,7 @@ func EmbeddingHandler(c *gin.Context) {
|
||||
|
||||
embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
log.Printf("embedding generation failed: %v", err)
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
}
|
||||
@@ -414,8 +446,13 @@ func PullModelHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
var model string
|
||||
if req.Model != "" {
|
||||
model = req.Model
|
||||
} else if req.Name != "" {
|
||||
model = req.Name
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -433,7 +470,7 @@ func PullModelHandler(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
if err := PullModel(ctx, req.Name, regOpts, fn); err != nil {
|
||||
if err := PullModel(ctx, model, regOpts, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@@ -458,8 +495,13 @@ func PushModelHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
var model string
|
||||
if req.Model != "" {
|
||||
model = req.Model
|
||||
} else if req.Name != "" {
|
||||
model = req.Name
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -477,7 +519,7 @@ func PushModelHandler(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
|
||||
if err := PushModel(ctx, model, regOpts, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@@ -502,12 +544,17 @@ func CreateModelHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
var model string
|
||||
if req.Model != "" {
|
||||
model = req.Model
|
||||
} else if req.Name != "" {
|
||||
model = req.Name
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := ParseModelPath(req.Name).Validate(); err != nil {
|
||||
if err := ParseModelPath(model).Validate(); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -545,7 +592,7 @@ func CreateModelHandler(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
if err := CreateModel(ctx, req.Name, filepath.Dir(req.Path), commands, fn); err != nil {
|
||||
if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@@ -570,14 +617,19 @@ func DeleteModelHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
var model string
|
||||
if req.Model != "" {
|
||||
model = req.Model
|
||||
} else if req.Name != "" {
|
||||
model = req.Name
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := DeleteModel(req.Name); err != nil {
|
||||
if err := DeleteModel(model); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
@@ -610,15 +662,19 @@ func ShowModelHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
if req.Model != "" {
|
||||
// noop
|
||||
} else if req.Name != "" {
|
||||
req.Model = req.Name
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := GetModelInfo(req.Name)
|
||||
resp, err := GetModelInfo(req)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
@@ -628,13 +684,14 @@ func ShowModelHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func GetModelInfo(name string) (*api.ShowResponse, error) {
|
||||
model, err := GetModel(name)
|
||||
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelDetails := api.ModelDetails{
|
||||
ParentModel: model.ParentModel,
|
||||
Format: model.Config.ModelFormat,
|
||||
Family: model.Config.ModelFamily,
|
||||
Families: model.Config.ModelFamilies,
|
||||
@@ -642,11 +699,45 @@ func GetModelInfo(name string) (*api.ShowResponse, error) {
|
||||
QuantizationLevel: model.Config.FileType,
|
||||
}
|
||||
|
||||
if req.System != "" {
|
||||
model.System = req.System
|
||||
}
|
||||
|
||||
if req.Template != "" {
|
||||
model.Template = req.Template
|
||||
}
|
||||
|
||||
msgs := make([]api.Message, 0)
|
||||
for _, msg := range model.Messages {
|
||||
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
|
||||
}
|
||||
|
||||
resp := &api.ShowResponse{
|
||||
License: strings.Join(model.License, "\n"),
|
||||
System: model.System,
|
||||
Template: model.Template,
|
||||
Details: modelDetails,
|
||||
Messages: msgs,
|
||||
}
|
||||
|
||||
var params []string
|
||||
cs := 30
|
||||
for k, v := range model.Options {
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
for _, nv := range val {
|
||||
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
|
||||
}
|
||||
default:
|
||||
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
|
||||
}
|
||||
}
|
||||
resp.Parameters = strings.Join(params, "\n")
|
||||
|
||||
for k, v := range req.Options {
|
||||
if _, ok := req.Options[k]; ok {
|
||||
model.Options[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
mf, err := ShowModelfile(model)
|
||||
@@ -656,41 +747,12 @@ func GetModelInfo(name string) (*api.ShowResponse, error) {
|
||||
|
||||
resp.Modelfile = mf
|
||||
|
||||
var params []string
|
||||
cs := 30
|
||||
for k, v := range model.Options {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
params = append(params, fmt.Sprintf("%-*s %s", cs, k, val))
|
||||
case int:
|
||||
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(val)))
|
||||
case float64:
|
||||
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(val, 'f', 0, 64)))
|
||||
case bool:
|
||||
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(val)))
|
||||
case []interface{}:
|
||||
for _, nv := range val {
|
||||
switch nval := nv.(type) {
|
||||
case string:
|
||||
params = append(params, fmt.Sprintf("%-*s %s", cs, k, nval))
|
||||
case int:
|
||||
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(nval)))
|
||||
case float64:
|
||||
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(nval, 'f', 0, 64)))
|
||||
case bool:
|
||||
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(nval)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
resp.Parameters = strings.Join(params, "\n")
|
||||
|
||||
return resp, nil
|
||||
}
func ListModelsHandler(c *gin.Context) {
    models := make([]api.ModelResponse, 0)
    fp, err := GetManifestPath()
    manifestsPath, err := GetManifestPath()
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
@@ -711,6 +773,7 @@ func ListModelsHandler(c *gin.Context) {
    }

    return api.ModelResponse{
        Model:  model.ShortName,
        Name:   model.ShortName,
        Size:   model.Size,
        Digest: model.Digest,
@@ -720,13 +783,15 @@ func ListModelsHandler(c *gin.Context) {

    walkFunc := func(path string, info os.FileInfo, _ error) error {
        if !info.IsDir() {
            dir, file := filepath.Split(path)
            dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
            tag := strings.Join([]string{dir, file}, ":")
            path, tag := filepath.Split(path)
            model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator))
            modelPath := strings.Join([]string{model, tag}, ":")
            canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/")

            resp, err := modelResponse(tag)
            resp, err := modelResponse(canonicalModelPath)
            if err != nil {
                log.Printf("skipping file: %s", fp)
                slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath))
                // nolint: nilerr
                return nil
            }

@@ -737,7 +802,7 @@ func ListModelsHandler(c *gin.Context) {
        return nil
    }

    if err := filepath.Walk(fp, walkFunc); err != nil {
    if err := filepath.Walk(manifestsPath, walkFunc); err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }
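A rough sketch of what the new path handling computes, using a hypothetical manifest location (not taken from the diff): the manifest file path is split into directory and tag, the manifests prefix is trimmed, and path separators are normalized to "/" so the resulting model name is the same on every OS:

```go
package main

import (
    "fmt"
    "os"
    "path/filepath"
    "strings"
)

func main() {
    // hypothetical manifest location under the manifests directory
    manifestsPath := "/manifests"
    fullPath := "/manifests/registry.ollama.ai/library/llama2/latest"

    dir, tag := filepath.Split(fullPath) // dir keeps its trailing separator, tag is the file name
    model := strings.Trim(strings.TrimPrefix(dir, manifestsPath), string(os.PathSeparator))
    modelPath := strings.Join([]string{model, tag}, ":")
    canonical := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/")

    fmt.Println(canonical) // registry.ollama.ai/library/llama2:latest (on Unix)
}
```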
@@ -837,6 +902,7 @@ func (s *Server) GenerateRoutes() http.Handler {

    config := cors.DefaultConfig()
    config.AllowWildcard = true
    config.AllowBrowserExtensions = true

    config.AllowOrigins = origins
    for _, allowOrigin := range defaultAllowOrigins {
@@ -884,6 +950,26 @@ func (s *Server) GenerateRoutes() http.Handler {
    }
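For context, this configures gin-contrib/cors; a small standalone sketch (origin list and port are made up) of what enabling AllowWildcard is intended to allow, assuming the library's wildcard matching behaves as documented:

```go
package main

import (
    "github.com/gin-contrib/cors"
    "github.com/gin-gonic/gin"
)

func main() {
    config := cors.DefaultConfig()
    config.AllowWildcard = true          // enables wildcard entries in AllowOrigins
    config.AllowBrowserExtensions = true // permits extension origins such as chrome-extension://
    config.AllowOrigins = []string{
        "https://*.example.com", // hypothetical wildcard origin
        "http://localhost",
    }

    router := gin.Default()
    router.Use(cors.New(config))
    _ = router.Run(":8080") // port is illustrative
}
```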

func Serve(ln net.Listener) error {
    level := slog.LevelInfo
    if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
        level = slog.LevelDebug
    }

    handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
        Level:     level,
        AddSource: true,
        ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
            if attr.Key == slog.SourceKey {
                source := attr.Value.Any().(*slog.Source)
                source.File = filepath.Base(source.File)
            }

            return attr
        },
    })

    slog.SetDefault(slog.New(handler))

    if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
        // clean up unused layers and manifests
        if err := PruneLayers(); err != nil {
@@ -906,7 +992,7 @@ func Serve(ln net.Listener) error {
    }
    r := s.GenerateRoutes()

    log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
    slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
    srvr := &http.Server{
        Handler: r,
    }
@@ -929,7 +1015,7 @@ func Serve(ln net.Listener) error {
    if runtime.GOOS == "linux" { // TODO - windows too
        // check compatibility to log warnings
        if _, err := gpu.CheckVRAM(); err != nil {
            log.Print(err.Error())
            slog.Info(err.Error())
        }
    }
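A minimal standalone sketch of the logging setup introduced above: the level is driven by OLLAMA_DEBUG, and ReplaceAttr trims the source path to its base name so log lines show something like routes.go:123 rather than a full path (the sample log call and output are illustrative):

```go
package main

import (
    "log/slog"
    "os"
    "path/filepath"
)

func main() {
    level := slog.LevelInfo
    if os.Getenv("OLLAMA_DEBUG") != "" {
        level = slog.LevelDebug
    }

    handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
        Level:     level,
        AddSource: true,
        ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
            if attr.Key == slog.SourceKey {
                // keep only the base file name in the source attribute
                source := attr.Value.Any().(*slog.Source)
                source.File = filepath.Base(source.File)
            }
            return attr
        },
    })
    slog.SetDefault(slog.New(handler))

    slog.Info("server started") // e.g. source=main.go:27 level=INFO msg="server started"
}
```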

@@ -971,14 +1057,14 @@ func streamResponse(c *gin.Context, ch chan any) {

        bts, err := json.Marshal(val)
        if err != nil {
            log.Printf("streamResponse: json.Marshal failed with %s", err)
            slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
            return false
        }

        // Delineate chunks with new-line delimiter
        bts = append(bts, '\n')
        if _, err := w.Write(bts); err != nil {
            log.Printf("streamResponse: w.Write failed with %s", err)
            slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
            return false
        }
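Because each chunk is terminated with a newline, a client can consume the stream as newline-delimited JSON. A minimal consumer sketch follows; the URL, port, request body, and response field are assumptions for illustration, not taken from the diff:

```go
package main

import (
    "bufio"
    "encoding/json"
    "fmt"
    "net/http"
    "strings"
)

func main() {
    // hypothetical request against a locally running server
    body := strings.NewReader(`{"model": "llama2", "messages": [{"role": "user", "content": "hi"}]}`)
    resp, err := http.Post("http://localhost:11434/api/chat", "application/json", body)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    // one JSON object per line, as written by streamResponse
    scanner := bufio.NewScanner(resp.Body)
    for scanner.Scan() {
        var chunk map[string]interface{}
        if err := json.Unmarshal(scanner.Bytes(), &chunk); err != nil {
            panic(err)
        }
        fmt.Println(chunk["message"]) // field name assumed from api.ChatResponse
    }
}
```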

@@ -1033,7 +1119,14 @@ func ChatHandler(c *gin.Context) {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }
    sessionDuration := defaultSessionDuration

    var sessionDuration time.Duration
    if req.KeepAlive == nil {
        sessionDuration = defaultSessionDuration
    } else {
        sessionDuration = req.KeepAlive.Duration
    }
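The keep-alive selection above reads as: use the duration the client sent, otherwise fall back to the server default. A tiny sketch of that rule as a helper in the same package; the *api.Duration type and the helper name are assumptions inferred from the fields used above:

```go
// sessionDurationFor is a hypothetical helper mirroring the branch above.
// It assumes req.KeepAlive is a *api.Duration wrapping a time.Duration.
func sessionDurationFor(keepAlive *api.Duration, fallback time.Duration) time.Duration {
    if keepAlive == nil {
        return fallback
    }
    return keepAlive.Duration
}
```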

    if err := load(c, model, opts, sessionDuration); err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
@@ -1041,18 +1134,32 @@ func ChatHandler(c *gin.Context) {

    // an empty request loads the model
    if len(req.Messages) == 0 {
        c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}})
        resp := api.ChatResponse{
            CreatedAt: time.Now().UTC(),
            Model:     req.Model,
            Done:      true,
            Message:   api.Message{Role: "assistant"},
        }
        c.JSON(http.StatusOK, resp)
        return
    }

    checkpointLoaded := time.Now()

    prompt, images, err := model.ChatPrompt(req.Messages)
    chat, err := model.ChatPrompts(req.Messages)
    if err != nil {
        c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }

    prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }

    slog.Debug("chat handler", "prompt", prompt)

    ch := make(chan any)

    go func() {
@@ -1126,3 +1233,115 @@ func ChatHandler(c *gin.Context) {

    streamResponse(c, ch)
}

// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
type promptInfo struct {
    vars     PromptVars
    tokenLen int
}

// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
// while preserving the most recent system message.
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
    if len(chat.Prompts) == 0 {
        return "", nil, nil
    }

    var promptsToAdd []promptInfo
    var totalTokenLength int
    var systemPromptIncluded bool

    var images []llm.ImageData
    // reverse iterate through the prompts to build the prompt string in a way that fits the max context length
    for i := len(chat.Prompts) - 1; i >= 0; i-- {
        prompt := chat.Prompts[i]
        promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
        if err != nil {
            return "", nil, err
        }

        encodedTokens, err := loaded.runner.Encode(ctx, promptText)
        if err != nil {
            return "", nil, err
        }

        if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
            break // reached max context length, stop adding more prompts
        }

        for j := range prompt.Images {
            if totalTokenLength+768 > loaded.NumCtx {
                // this decreases the token length but overestimating is fine
                prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
                continue
            }

            totalTokenLength += 768
            images = append(images, prompt.Images[j])
        }

        totalTokenLength += len(encodedTokens)
        systemPromptIncluded = systemPromptIncluded || prompt.System != ""
        promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
    }

    // ensure the system prompt is included, if not already
    if chat.LastSystem != "" && !systemPromptIncluded {
        var err error
        promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
        if err != nil {
            return "", nil, err
        }
    }

    promptsToAdd[len(promptsToAdd)-1].vars.First = true

    // construct the final prompt string from the prompts which fit within the context window
    var result string
    for i, prompt := range promptsToAdd {
        promptText, err := promptString(model, prompt.vars, i == 0)
        if err != nil {
            return "", nil, err
        }
        result = promptText + result
    }

    return result, images, nil
}
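To make the budgeting concrete, a hedged walk-through with made-up numbers: suppose loaded.NumCtx is 100 and the chat has three prompts whose rendered templates encode to 40, 30, and 50 tokens (newest last). The loop starts at the newest prompt (50 tokens), which is always kept even if it alone exceeds the budget, then adds the middle prompt (50 + 30 = 80, still within 100), and stops before the oldest because 80 + 40 would exceed 100. Each image attached to a kept prompt is budgeted at a flat 768 tokens; an image that would push the total over NumCtx is dropped and its " [img-N]" placeholder is stripped from the prompt text instead.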

// promptString applies the model template to the prompt
func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
    if isMostRecent {
        p, err := model.PreResponsePrompt(vars)
        if err != nil {
            return "", fmt.Errorf("pre-response template: %w", err)
        }
        return p, nil
    }
    p, err := Prompt(model.Template, vars)
    if err != nil {
        return "", err
    }
    return p, nil
}

// includeSystemPrompt adjusts the prompts to include the system prompt.
func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
    systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
    if err != nil {
        return nil, err
    }

    for i := len(promptsToAdd) - 1; i >= 0; i-- {
        if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
            promptsToAdd[i].vars.System = systemPrompt
            return promptsToAdd[:i+1], nil
        }
        totalTokenLength -= promptsToAdd[i].tokenLen
    }

    // if got here, system did not fit anywhere, so return the most recent prompt with the system message set
    recent := promptsToAdd[len(promptsToAdd)-1]
    recent.vars.System = systemPrompt
    return []promptInfo{recent}, nil
}
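Continuing the made-up numbers: promptsToAdd is in newest-first order, so its last element is the oldest kept prompt. If the kept prompts total 80 tokens (newest 50, then 30) and the system prompt encodes to 30 tokens, 80 + 30 exceeds 100, so the oldest kept prompt's 30 tokens are subtracted first; 50 + 30 now fits, so the system text is attached to the remaining prompt and only that slice prefix is returned. If nothing fits even after dropping everything else, the fallback returns just the most recent prompt with System set, which is the behavior the "System is Preserved when Length Exceeded" test case below exercises.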

@@ -9,12 +9,14 @@ import (
    "net/http"
    "net/http/httptest"
    "os"
    "sort"
    "strings"
    "testing"

    "github.com/stretchr/testify/assert"

    "github.com/jmorganca/ollama/api"
    "github.com/jmorganca/ollama/llm"
    "github.com/jmorganca/ollama/parser"
    "github.com/jmorganca/ollama/version"
)
@@ -50,7 +52,7 @@ func Test_Routes(t *testing.T) {
    createTestModel := func(t *testing.T, name string) {
        fname := createTestFile(t, "ollama-model")

        modelfile := strings.NewReader(fmt.Sprintf("FROM %s", fname))
        modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
        commands, err := parser.Parse(modelfile)
        assert.Nil(t, err)
        fn := func(resp api.ProgressResponse) {
@@ -167,6 +169,42 @@ func Test_Routes(t *testing.T) {
                assert.Equal(t, "beefsteak:latest", model.ShortName)
            },
        },
        {
            Name:   "Show Model Handler",
            Method: http.MethodPost,
            Path:   "/api/show",
            Setup: func(t *testing.T, req *http.Request) {
                createTestModel(t, "show-model")
                showReq := api.ShowRequest{Model: "show-model"}
                jsonData, err := json.Marshal(showReq)
                assert.Nil(t, err)
                req.Body = io.NopCloser(bytes.NewReader(jsonData))
            },
            Expected: func(t *testing.T, resp *http.Response) {
                contentType := resp.Header.Get("Content-Type")
                assert.Equal(t, contentType, "application/json; charset=utf-8")
                body, err := io.ReadAll(resp.Body)
                assert.Nil(t, err)

                var showResp api.ShowResponse
                err = json.Unmarshal(body, &showResp)
                assert.Nil(t, err)

                var params []string
                paramsSplit := strings.Split(showResp.Parameters, "\n")
                for _, p := range paramsSplit {
                    params = append(params, strings.Join(strings.Fields(p), " "))
                }
                sort.Strings(params)
                expectedParams := []string{
                    "seed 42",
                    "stop \"bar\"",
                    "stop \"foo\"",
                    "top_p 0.9",
                }
                assert.Equal(t, expectedParams, params)
            },
        },
    }

    s, err := setupServer(t)
@@ -193,13 +231,267 @@ func Test_Routes(t *testing.T) {
        }

        resp, err := httpSrv.Client().Do(req)
        defer resp.Body.Close()
        assert.Nil(t, err)
        defer resp.Body.Close()

        if tc.Expected != nil {
            tc.Expected(t, resp)
        }

    }

}
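The "Show Model Handler" assertions above depend on collapsing the padded parameter columns before comparing; a standalone sketch of that normalization (the sample line mirrors the `%-*s` padding produced by GetModelInfo):

```go
package main

import (
    "fmt"
    "strings"
)

func main() {
    // a line as produced by the %-*s padding in the parameters output
    line := "seed                           42"

    // strings.Fields splits on any run of whitespace; Join rebuilds with single spaces
    normalized := strings.Join(strings.Fields(line), " ")
    fmt.Println(normalized) // "seed 42"
}
```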

func Test_ChatPrompt(t *testing.T) {
    tests := []struct {
        name     string
        template string
        chat     *ChatHistory
        numCtx   int
        runner   MockLLM
        want     string
        wantErr  string
    }{
        {
            name:     "Single Message",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        System: "You are a Wizard.",
                        Prompt: "What are the potion ingredients?",
                        First:  true,
                    },
                },
                LastSystem: "You are a Wizard.",
            },
            numCtx: 1,
            runner: MockLLM{
                encoding: []int{1}, // fit the ctxLen
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
        },
        {
            name:     "First Message",
            template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        System:   "You are a Wizard.",
                        Prompt:   "What are the potion ingredients?",
                        Response: "eye of newt",
                        First:    true,
                    },
                    {
                        Prompt: "Anything else?",
                    },
                },
                LastSystem: "You are a Wizard.",
            },
            numCtx: 2,
            runner: MockLLM{
                encoding: []int{1}, // fit the ctxLen
            },
            want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
        },
        {
            name:     "Message History",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        System:   "You are a Wizard.",
                        Prompt:   "What are the potion ingredients?",
                        Response: "sugar",
                        First:    true,
                    },
                    {
                        Prompt: "Anything else?",
                    },
                },
                LastSystem: "You are a Wizard.",
            },
            numCtx: 4,
            runner: MockLLM{
                encoding: []int{1}, // fit the ctxLen, 1 for each message
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
        },
        {
            name:     "Assistant Only",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        Response: "everything nice",
                        First:    true,
                    },
                },
            },
            numCtx: 1,
            runner: MockLLM{
                encoding: []int{1},
            },
            want: "[INST] [/INST]everything nice",
        },
        {
            name:     "Message History Truncated, No System",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        Prompt:   "What are the potion ingredients?",
                        Response: "sugar",
                        First:    true,
                    },
                    {
                        Prompt:   "Anything else?",
                        Response: "spice",
                    },
                    {
                        Prompt: "... and?",
                    },
                },
            },
            numCtx: 2, // only 1 message from history and most recent message
            runner: MockLLM{
                encoding: []int{1},
            },
            want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]",
        },
        {
            name:     "System is Preserved when Truncated",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        Prompt:   "What are the magic words?",
                        Response: "abracadabra",
                    },
                    {
                        Prompt: "What is the spell for invisibility?",
                    },
                },
                LastSystem: "You are a wizard.",
            },
            numCtx: 2,
            runner: MockLLM{
                encoding: []int{1},
            },
            want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
        },
        {
            name:     "System is Preserved when Length Exceeded",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        Prompt:   "What are the magic words?",
                        Response: "abracadabra",
                    },
                    {
                        Prompt: "What is the spell for invisibility?",
                    },
                },
                LastSystem: "You are a wizard.",
            },
            numCtx: 1,
            runner: MockLLM{
                encoding: []int{1},
            },
            want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
        },
        {
            name:     "First is Preserved when Truncated",
            template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",

            chat: &ChatHistory{
                Prompts: []PromptVars{
                    // first message omitted for test
                    {
                        Prompt:   "Do you have a magic hat?",
                        Response: "Of course.",
                    },
                    {
                        Prompt: "What is the spell for invisibility?",
                    },
                },
                LastSystem: "You are a wizard.",
            },
            numCtx: 3, // two most recent messages and room for system message
            runner: MockLLM{
                encoding: []int{1},
            },
            want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
        },
        {
            name:     "Most recent message is returned when longer than ctxLen",
            template: "[INST] {{ .Prompt }} [/INST]",

            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        Prompt: "What is the spell for invisibility?",
                        First:  true,
                    },
                },
            },
            numCtx: 1, // two most recent messages
            runner: MockLLM{
                encoding: []int{1, 2},
            },
            want: "[INST] What is the spell for invisibility? [/INST]",
        },
    }

    for _, testCase := range tests {
        tt := testCase
        m := &Model{
            Template: tt.template,
        }
        t.Run(tt.name, func(t *testing.T) {
            loaded.runner = &tt.runner
            loaded.Options = &api.Options{
                Runner: api.Runner{
                    NumCtx: tt.numCtx,
                },
            }
            // TODO: add tests for trimming images
            got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
            if tt.wantErr != "" {
                if err == nil {
                    t.Errorf("ChatPrompt() expected error, got nil")
                }
                if !strings.Contains(err.Error(), tt.wantErr) {
                    t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
                }
            }
            if got != tt.want {
                t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
            }
        })
    }
}

type MockLLM struct {
    encoding []int
}

func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
    return nil
}

func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
    return llm.encoding, nil
}

func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
    return "", nil
}

func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
    return []float64{}, nil
}

func (llm *MockLLM) Close() {
    // do nothing
}
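The mock makes token accounting deterministic: Encode always returns the fixed slice, so with encoding: []int{1} every rendered prompt "costs" exactly one token and numCtx effectively counts how many chat turns trimmedPrompt may keep, which is what the table-driven cases above exploit. A short fragment (within the same test package, for illustration only):

```go
// every Encode call returns the fixed slice, regardless of the prompt text
mock := &MockLLM{encoding: []int{1}}
tokens, _ := mock.Encode(context.Background(), "any prompt at all")
fmt.Println(len(tokens)) // always 1
```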

@@ -7,7 +7,7 @@ import (
    "fmt"
    "hash"
    "io"
    "log"
    "log/slog"
    "math"
    "net/http"
    "net/url"
@@ -88,7 +88,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
        return nil
    }

    var size = b.Total / numUploadParts
    size := b.Total / numUploadParts
    switch {
    case size < minUploadPartSize:
        size = minUploadPartSize
@@ -107,7 +107,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
        offset += size
    }

    log.Printf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
    slog.Info(fmt.Sprintf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))

    requestURL, err = url.Parse(location)
    if err != nil {
@@ -156,7 +156,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
            return err
        case err != nil:
            sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
            log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
            slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
            time.Sleep(sleep)
            continue
        }
@@ -200,7 +200,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
            break
        } else if err != nil {
            sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
            log.Printf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep)
            slog.Info(fmt.Sprintf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep))
            time.Sleep(sleep)
            continue
        }
@@ -265,7 +265,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
            return err
        case err != nil:
            sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
            log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
            slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
            time.Sleep(sleep)
            continue
        }
@@ -395,6 +395,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryO
        return err
    }

    // nolint: contextcheck
    go upload.Run(context.Background(), opts)
}
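The retry sleep in the upload code grows exponentially with the attempt number. A small worked sketch of the same formula (the loop bound is illustrative):

```go
package main

import (
    "fmt"
    "math"
    "time"
)

func main() {
    for try := 0; try < 4; try++ {
        // same formula as in blobUpload: 2^try seconds between retries
        sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
        fmt.Println(try, sleep) // 0 1s, 1 2s, 2 4s, 3 8s
    }
}
```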