commit 0d22c0ec1a (parent a34960a516)

    new runner

cache/cache.go (vendored) | 373 lines changed
@@ -1,63 +1,384 @@
 package cache
 
 import (
+    "errors"
+    "fmt"
+    "log/slog"
+    "math"
+    "slices"
+
     "github.com/ollama/ollama/ml"
 )
 
-type Options struct {
-    Position int
-}
+var ErrNotSupported = errors.New("model does not support operation")
 
 type Cache interface {
+    // used by model implementations
     Sub(i int) Cache
-    Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor)
+    Put(ctx ml.Context, key, value ml.Tensor) (ml.Tensor, ml.Tensor, ml.Tensor)
+
+    // cache management
+    Close()
+
+    StartForward(ctx ml.Context, seqs []int) error
+
+    CopyPrefix(srcSeq, dstSeq int, len int)
+    Remove(seq int, beginIndex, endIndex int) error
 }
 
-type Simple struct {
+type Causal struct {
     DType    ml.DType
     Capacity int
 
+    // current forward pass
+    curLayer     int
+    curPos       int
+    curBatchSize int
+    curMask      ml.Tensor
+    curCellRange cellRange
+
+    // metadata
+    cells      []cacheCell
+    seqNextPos map[int]int
+    cellRanges map[int]cellRange
+
+    // cache data storage
+    backend      ml.Backend
+    cacheCtx     ml.Context
     keys, values []ml.Tensor
 }
 
-func (c *Simple) Sub(i int) Cache {
+type seqCell struct {
+    seq int
+    pos int
+}
+
+type cacheCell struct {
+    sequences []seqCell
+}
+
+type cellRange struct {
+    min int
+    max int
+}
+
+func (cell cacheCell) findSeq(seq int) *seqCell {
+    for i := range cell.sequences {
+        if cell.sequences[i].seq == seq {
+            return &cell.sequences[i]
+        }
+    }
+    return nil
+}
+
+func NewCausalCache(backend ml.Backend, capacity int, dtype ml.DType) Cache {
+    return &Causal{
+        Capacity:   capacity,
+        DType:      dtype,
+        cells:      make([]cacheCell, capacity),
+        seqNextPos: make(map[int]int),
+        cellRanges: make(map[int]cellRange),
+        backend:    backend,
+        // TODO(jessegross): This context is not sized appropriately
+        cacheCtx: backend.NewContext(),
+    }
+}
+
+func (c *Causal) Close() {
+    c.cacheCtx.Close()
+}
+
+var ErrKvCacheFull = errors.New("could not find a kv cache slot")
+
+func (c *Causal) StartForward(ctx ml.Context, seqs []int) error {
+    c.curBatchSize = len(seqs)
+
+    var err error
+    c.curPos, err = c.findStartPos()
+    if errors.Is(err, ErrKvCacheFull) {
+        c.defrag()
+        c.curPos, err = c.findStartPos()
+    }
+    if err != nil {
+        return err
+    }
+
+    // TODO(jessegross): There should be a better way to do this
+    origSeq := make(map[int]int)
+    for k, v := range c.seqNextPos {
+        origSeq[k] = v
+    }
+
+    c.curCellRange = newRange()
+    for i, seq := range seqs {
+        c.cells[c.curPos+i] = cacheCell{sequences: []seqCell{{seq: seq, pos: c.seqNextPos[seq]}}}
+        c.seqNextPos[seq]++
+
+        ranges := c.cellRanges[seq]
+        if c.curPos+i > ranges.max {
+            ranges.max = c.curPos + i
+        }
+        if ranges.max > c.curCellRange.max {
+            c.curCellRange.max = ranges.max
+        }
+
+        if c.curPos+i < ranges.min {
+            ranges.min = c.curPos + i
+        }
+        if ranges.min < c.curCellRange.min {
+            c.curCellRange.min = ranges.min
+        }
+        c.cellRanges[seq] = ranges
+    }
+
+    c.curMask, err = c.buildMask(ctx, origSeq, seqs)
+
+    return err
+}
+
+func newRange() cellRange {
+    return cellRange{
+        min: math.MaxInt,
+        max: 0,
+    }
+}
+
+func (c *Causal) findStartPos() (int, error) {
+    var start, count int
+    for i := range c.cells {
+        if len(c.cells[i].sequences) == 0 {
+            count++
+            if count >= c.curBatchSize {
+                return start, nil
+            }
+        } else {
+            start = i + 1
+            count = 0
+        }
+    }
+
+    return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
+}
+
+func (c *Causal) buildMask(ctx ml.Context, origSeq map[int]int, seqs []int) (ml.Tensor, error) {
+    // TODO(jessegross): This makes a number of simplifications such as no padding
+    len := c.curCellRange.max - c.curCellRange.min
+    mask := make([]float32, c.curBatchSize*len)
+
+    for i := range c.curBatchSize {
+        for j := c.curCellRange.min; j < c.curCellRange.max; j++ {
+            cellSeq := c.cells[j].findSeq(seqs[i])
+            if cellSeq == nil || cellSeq.pos > origSeq[seqs[i]]+i {
+                mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
+            }
+        }
+    }
+
+    return ctx.FromFloatSlice(mask, len, c.curBatchSize)
+}
+
+func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
+    for _, obj := range objs {
+        srcView := obj.View(ctx, int(obj.Stride(2))*src, int(obj.Dim(0)*obj.Dim(1))*len)
+        dstView := obj.View(ctx, int(obj.Stride(2))*dst, int(obj.Dim(0)*obj.Dim(1))*len)
+
+        ctx.Forward(srcView.Copy(ctx, dstView))
+    }
+}
+
+func (c *Causal) defrag() {
+    slog.Debug("defragmenting kv cache")
+
+    // Defrag strategy:
+    // - Search for empty holes at the beginning of the cache,
+    //   filling them with active data starting at the end
+    // - If there are contiguous elements that need to be moved,
+    //   combine them into a single operation by holding new moves
+    //   until we see the next one is non-contiguous
+    // - Fill up the context with the maximum number of operations it
+    //   can hold then compute that and continue with a new context
+
+    // TODO(jessegross):
+    // - Need to size the context and compute maxMoves correctly
+    // - Just compacts, doesn't optimize placement
+    maxMoves := 8192 / (6 * len(c.keys))
+
+    ctx := c.backend.NewContext()
+    moves := 0
+
+    var pendingSrc, pendingDst, pendingLen int
+
+    for dst := range c.cells {
+        if len(c.cells[dst].sequences) == 0 {
+            for src := len(c.cells) - 1; src > dst; src-- {
+                if len(c.cells[src].sequences) != 0 {
+                    c.cells[dst] = c.cells[src]
+                    c.cells[src] = cacheCell{}
+
+                    if pendingLen > 0 {
+                        if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
+                            pendingSrc = src
+                            pendingLen++
+                            break
+                        } else {
+                            moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
+                            moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+                            moves++
+                        }
+                    }
+
+                    pendingSrc = src
+                    pendingDst = dst
+                    pendingLen = 1
+
+                    break
+                }
+            }
+        }
+
+        if moves >= maxMoves {
+            ctx.Compute(nil)
+            ctx.Close()
+            ctx = c.backend.NewContext()
+
+            moves = 0
+        }
+    }
+
+    if pendingLen > 0 {
+        moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
+        moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+        moves++
+    }
+
+    if moves > 0 {
+        ctx.Compute(nil)
+    }
+    ctx.Close()
+
+    for seq := range c.cellRanges {
+        seqRange := newRange()
+
+        for i, cell := range c.cells {
+            if cell.findSeq(seq) != nil {
+                if i < seqRange.min {
+                    seqRange.min = i
+                }
+                if i > seqRange.max {
+                    seqRange.max = i
+                }
+            }
+        }
+
+        c.cellRanges[seq] = seqRange
+    }
+}
+
+func (c *Causal) Sub(i int) Cache {
     if i >= len(c.keys) {
         c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
         c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
     }
 
-    return &Simple{
-        keys:     c.keys[i : i+1],
-        values:   c.values[i : i+1],
-        Capacity: c.Capacity,
-        DType:    c.DType,
-    }
+    c.curLayer = i
+
+    return c
 }
 
-func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) {
-    if c.keys[0] == nil || c.values[0] == nil {
-        c.keys[0] = ctx.Zeros(c.DType, int(key.Dim(0)*key.Dim(1))*c.Capacity)
-        c.values[0] = ctx.Zeros(c.DType, int(value.Dim(0)*value.Dim(1))*c.Capacity)
+func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) (ml.Tensor, ml.Tensor, ml.Tensor) {
+    if c.curBatchSize != int(key.Dim(2)) {
+        panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, int(key.Dim(2))))
     }
 
-    ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, int(key.Stride(2))*opts.Position, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
-    ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, int(value.Stride(2))*opts.Position, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))
+    if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
+        c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int64(c.Capacity))
+        c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int64(c.Capacity))
+    }
 
-    n := min(c.Capacity, int(key.Dim(2))+opts.Position)
+    ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, int(key.Stride(2))*c.curPos, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
+    ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, int(value.Stride(2))*c.curPos, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))
 
-    key = c.keys[0].View(ctx, 0,
+    len := c.curCellRange.max - c.curCellRange.min
+
+    key = c.keys[c.curLayer].View(ctx, int(key.Stride(2))*c.curCellRange.min,
         int(key.Dim(0)), int(key.Stride(1)),
         int(key.Dim(1)), int(key.Stride(2)),
-        n,
+        len,
     )
 
-    value = c.values[0].View(ctx, 0,
+    value = c.values[c.curLayer].View(ctx, int(key.Stride(2))*c.curCellRange.min,
         int(value.Dim(0)), int(value.Stride(1)),
         int(value.Dim(1)), int(value.Stride(2)),
-        n,
+        len,
     )
 
-    // TODO shift context if necessary
+    return key, value, c.curMask
+}
 
-    return key, value
+func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int) {
+    seqRange := newRange()
+
+    for i := range c.cells {
+        srcCellSeq := c.cells[i].findSeq(srcSeq)
+        dstCellSeq := c.cells[i].findSeq(dstSeq)
+
+        if dstCellSeq != nil {
+            c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s seqCell) bool { return s.seq == dstSeq })
+        }
+
+        if srcCellSeq != nil && srcCellSeq.pos < len {
+            c.cells[i].sequences = append(c.cells[i].sequences, seqCell{seq: dstSeq, pos: srcCellSeq.pos})
+            if i < seqRange.min {
+                seqRange.min = i
+            }
+            if i > seqRange.max {
+                seqRange.max = i
+            }
+        }
+    }
+
+    c.cellRanges[dstSeq] = seqRange
+    c.seqNextPos[dstSeq] = len
+}
+
+func (c *Causal) shift(seq int, beginIndex, endIndex, offset int) error {
+    panic("Shift not yet implemented")
+}
+
+func (c *Causal) Remove(seq int, beginIndex, endIndex int) error {
+    endIndex = min(endIndex, c.seqNextPos[seq])
+    offset := beginIndex - endIndex
+
+    seqRange := newRange()
+
+    for i := range c.cells {
+        cellSeq := c.cells[i].findSeq(seq)
+        if cellSeq != nil {
+            if cellSeq.pos >= beginIndex && cellSeq.pos < endIndex {
+                c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s seqCell) bool { return s.seq == seq })
+            } else {
+                if cellSeq.pos >= endIndex {
+                    cellSeq.pos += offset
+                }
+                if i < seqRange.min {
+                    seqRange.min = i
+                }
+                if i > seqRange.max {
+                    seqRange.max = i
+                }
+            }
+        }
+    }
+
+    if endIndex != c.seqNextPos[seq] {
+        err := c.shift(seq, endIndex, c.seqNextPos[seq], offset)
+        if err != nil {
+            return err
+        }
+    }
+
+    c.cellRanges[seq] = seqRange
+    c.seqNextPos[seq] += offset
+
+    return nil
 }
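For orientation, here is a minimal sketch of how a transformer layer is expected to consume the new interface, modeled on the llama changes later in this commit. The helper name and signature are hypothetical; only the cache calls and mask handling come from the diff.

package sketch

import (
    "math"

    "github.com/ollama/ollama/cache"
    "github.com/ollama/ollama/ml"
)

// attend is a hypothetical helper showing the per-layer call pattern.
func attend(ctx ml.Context, q, k, v ml.Tensor, layer, headDim int, kv cache.Cache) ml.Tensor {
    // Select this layer's slice of the cache, then append the new keys/values.
    // Put returns views over the cached K/V plus the mask built in StartForward.
    k, v, mask := kv.Sub(layer).Put(ctx, k, v)

    q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

    kq := k.Mulmat(ctx, q)
    kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
    kq = kq.Add(ctx, mask) // -Inf for cells outside this sequence or in the future
    kq = kq.Softmax(ctx)

    return v.Mulmat(ctx, kq)
}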
cache/tensor.go (vendored, new file) | 48 lines
@@ -0,0 +1,48 @@
package cache

import (
    "github.com/ollama/ollama/ml"
)

type TensorCache struct {
    curLayer int

    cacheCtx     ml.Context
    keys, values []ml.Tensor
}

func NewTensorCache(backend ml.Backend) *TensorCache {
    return &TensorCache{
        // TODO(jessegross): This context is not sized appropriately
        cacheCtx: backend.NewContext(),
    }
}

func (c *TensorCache) Close() {
    c.cacheCtx.Close()
}

func (c *TensorCache) Sub(i int) *TensorCache {
    if i >= len(c.keys) {
        c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
        c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
    }

    c.curLayer = i

    return c
}

func (c *TensorCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
    return c.keys[c.curLayer], c.values[c.curLayer], nil
}

func (c *TensorCache) Put(ctx ml.Context, key, value ml.Tensor) {
    if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
        c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
        c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
    }

    ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
    ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
}
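A short sketch of the intended usage, based on the mllama cross-attention changes later in this commit: keys and values are computed and stored once when cross-attention (vision) states are present, and reused from the cache on later decode steps. The helper below is hypothetical.

package sketch

import (
    "github.com/ollama/ollama/cache"
    "github.com/ollama/ollama/ml"
)

// crossAttentionKV mirrors how the mllama cross-attention layers use TensorCache.
func crossAttentionKV(ctx ml.Context, layer int, states, newKey, newValue ml.Tensor, tc *cache.TensorCache) (ml.Tensor, ml.Tensor) {
    tc = tc.Sub(layer)
    if states != nil {
        tc.Put(ctx, newKey, newValue) // first pass: store projected vision K/V
        return newKey, newValue
    }
    key, value, _ := tc.Get(ctx) // later passes: reuse the cached vision K/V
    return key, value
}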
@@ -35,9 +35,9 @@ import (
     "github.com/ollama/ollama/envconfig"
     "github.com/ollama/ollama/format"
     "github.com/ollama/ollama/llama"
-    "github.com/ollama/ollama/llama/runner"
     "github.com/ollama/ollama/parser"
     "github.com/ollama/ollama/progress"
+    "github.com/ollama/ollama/runner"
     "github.com/ollama/ollama/server"
     "github.com/ollama/ollama/types/model"
     "github.com/ollama/ollama/version"
@@ -338,7 +338,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
         return err
     }
 
-    opts.MultiModal = len(info.ProjectorInfo) != 0
+    opts.MultiModal = true //len(info.ProjectorInfo) != 0
     opts.ParentModel = info.Details.ParentModel
 
     if interactive {
@@ -4,7 +4,7 @@ import (
     "fmt"
     "os"
 
-    "github.com/ollama/ollama/llama/runner"
+    "github.com/ollama/ollama/runner"
 )
 
 func main() {
@@ -165,6 +165,8 @@ var (
     IntelGPU = Bool("OLLAMA_INTEL_GPU")
     // MultiUserCache optimizes prompt caching for multi-user scenarios
     MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
+    // Enable the new Ollama engine
+    NewRunners = Bool("OLLAMA_NEW_RUNNERS")
 )
 
 func String(s string) func() string {
@@ -250,6 +252,7 @@ func AsMap() map[string]EnvVar {
     "OLLAMA_ORIGINS":         {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
     "OLLAMA_SCHED_SPREAD":    {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
     "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
+    "OLLAMA_NEW_RUNNERS":     {"OLLAMA_NEW_RUNNERS", NewRunners(), "Enable the new Ollama engine"},
 
     // Informational
     "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
@@ -252,6 +252,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
         port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
     }
     finalParams := []string{"runner"}
+    if envconfig.NewRunners() {
+        finalParams = append(finalParams, "--new-runner")
+    }
     finalParams = append(finalParams, params...)
     finalParams = append(finalParams, "--port", strconv.Itoa(port))
 
@@ -43,7 +43,7 @@ func NewBackend(f *os.File) (Backend, error) {
 }
 
 type Context interface {
-    Zeros(dtype DType, shape ...int) Tensor
+    Zeros(dtype DType, shape ...int64) Tensor
     FromFloatSlice(s []float32, shape ...int) (Tensor, error)
     FromIntSlice(s []int32, shape ...int) (Tensor, error)
 
@@ -23,7 +23,7 @@ import (
     "github.com/ollama/ollama/ml"
     "golang.org/x/sync/errgroup"
 
-    "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+    ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
 )
 
 type device struct {
@@ -198,10 +198,9 @@ func (b *Backend) Get(name string) ml.Tensor {
 
 func (b *Backend) NewContext() ml.Context {
     nodes := max(8192, len(b.meta.Tensors().Items())*5)
-    bts := make([]byte, C.size_t(nodes)*C.ggml_tensor_overhead()+C.ggml_graph_overhead_custom(C.size_t(nodes), false))
     c := C.ggml_init(C.struct_ggml_init_params{
-        mem_buffer: unsafe.Pointer(&bts[0]),
-        mem_size:   C.size_t(len(bts)),
+        mem_buffer: nil,
+        mem_size:   C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false),
         no_alloc:   true,
     })
 
@@ -244,17 +243,19 @@ func (c *Context) Forward(t ml.Tensor) {
 }
 
 func (c *Context) Compute(t ml.Tensor) ml.Tensor {
-    c.Forward(t)
     C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
 
+    if t != nil && C.ggml_nbytes(t.(*Tensor).t) != 0 {
         backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
 
         t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
         C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+    }
 
     return t
 }
 
-func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
+func (c Context) Zeros(dtype ml.DType, shape ...int64) ml.Tensor {
     if len(shape) < 1 || len(shape) > 4 {
         panic("unsupported number of dimensions")
     }
@@ -283,6 +284,13 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 
 func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
     n := len(s)
+
+    if n == 0 {
+        shape := 0
+        t := C.ggml_new_tensor(ctx.ctx, dtype, 1, (*C.int64_t)(unsafe.Pointer(&shape)))
+        return &Tensor{t: t}, nil
+    }
+
     for _, v := range shape {
         n /= v
     }
@@ -1,160 +0,0 @@ (file removed; its previous contents follow)
package main

import (
    "errors"
    "flag"
    "fmt"
    "image"
    "io"
    "log/slog"
    "os"
    "path/filepath"
    "strings"

    "github.com/ollama/ollama/cache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model"
    _ "github.com/ollama/ollama/model/llama"
    _ "github.com/ollama/ollama/model/mllama"
    "github.com/ollama/ollama/sample"
)

var args struct {
    n     int
    debug bool
    image string
    cache bool
}

func temp() error {
    flag.IntVar(&args.n, "n", 10, "number of samples")
    flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
    flag.StringVar(&args.image, "image", "", "path to image file")
    flag.BoolVar(&args.cache, "cache", false, "enable KV cache")

    flag.Parse()

    var prompt string
    if n := len(flag.Args()); n == 1 {
        bts, err := io.ReadAll(os.Stdin)
        if err != nil {
            return err
        }

        prompt = string(bts)
    } else if n > 1 {
        prompt = strings.Join(flag.Args()[1:], " ")
    } else {
        return fmt.Errorf("usage: %s path/to/file <prompt\n", filepath.Base(os.Args[0]))
    }

    level := slog.LevelInfo
    if args.debug {
        level = slog.LevelDebug
    }

    slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
        Level:     level,
        AddSource: true,
        ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
            if attr.Key == slog.SourceKey {
                source := attr.Value.Any().(*slog.Source)
                source.File = filepath.Base(source.File)
            }

            return attr
        },
    })))

    m, err := model.New(flag.Arg(0))
    if err != nil {
        return err
    }

    inputIDs, err := m.(model.TextProcessor).Encode(prompt)
    if err != nil {
        return err
    }

    var opts []model.OptionsFunc
    if args.cache {
        opts = append(opts, model.WithCache(&cache.Simple{
            Capacity: 2048,
            DType:    ml.DTypeF32,
        }))
    }

    if args.image != "" {
        if err := func() error {
            f, err := os.Open(args.image)
            if err != nil {
                return err
            }
            defer f.Close()

            img, _, err := image.Decode(f)
            if err != nil {
                return err
            }

            opts = append(opts, model.WithImage(img))
            return nil
        }(); err != nil {
            return err
        }
    }

    var offset int
    for range args.n {
        logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
        if err != nil {
            return err
        }

        f32s := logit.Floats()
        f64s := make([]float64, len(f32s))
        for i, f32 := range f32s {
            f64s[i] = float64(f32)
        }

        // do sampling
        f64s, err = sample.Sample(f64s, sample.Greedy())
        if err != nil {
            return err
        }

        var outputIDs []int32
        for _, f64 := range f64s {
            if !m.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) {
                outputIDs = append(outputIDs, int32(f64))
            }
        }

        if len(outputIDs) == 0 {
            break
        }

        s, err := m.(model.TextProcessor).Decode(outputIDs)
        if errors.Is(err, io.EOF) {
            break
        } else if err != nil {
            return err
        }

        fmt.Print(s)

        inputIDs = append(inputIDs, outputIDs...)
        if args.cache {
            offset = len(inputIDs) - 1
        }
    }

    return nil
}

func main() {
    if err := temp(); err != nil {
        fmt.Println("err", err)
        os.Exit(1)
    }
}
@@ -3,6 +3,7 @@ package llama
 import (
     "math"
 
+    "github.com/ollama/ollama/cache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
     "github.com/ollama/ollama/model"
@@ -59,7 +60,7 @@ type SelfAttention struct {
     Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor {
     batchSize := hiddenState.Dim(1)
     headDim := opts.hiddenSize / opts.numHeads
 
@@ -74,7 +75,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-    k, v = cache.Put(ctx, k, v, cache.Options)
+    k, v, mask := cache.Put(ctx, k, v)
 
     q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
     k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
@@ -82,6 +83,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
     kq := k.Mulmat(ctx, q)
     kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
+    kq = kq.Add(ctx, mask)
     kq = kq.Softmax(ctx)
 
     kqv := v.Mulmat(ctx, kq)
@@ -109,7 +111,7 @@ type Layer struct {
     MLP *MLP
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor {
     residual := hiddenState
 
     hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -142,7 +144,7 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
     hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
     hiddenState = m.Output.Forward(ctx, hiddenState)
 
-    outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
+    outputs, err := ctx.FromIntSlice(opts.Outputs(), len(opts.Outputs()))
     if err != nil {
         return nil, err
     }
@@ -1,6 +1,9 @@
 package mllama
 
 import (
+    "sync"
+
+    "github.com/ollama/ollama/cache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
     "github.com/ollama/ollama/model"
@@ -16,6 +19,9 @@ type Model struct {
 
     ImageProcessor
     TextProcessor
+
+    start  sync.Once
+    tCache *cache.TensorCache
 }
 
 func New(c ml.Config) (model.Model, error) {
@@ -28,6 +34,10 @@ func New(c ml.Config) (model.Model, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+    m.start.Do(func() {
+        m.tCache = cache.NewTensorCache(m.Backend())
+    })
+
     var crossAttentionStates ml.Tensor
     if opts.Images != nil {
         f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
@@ -75,9 +85,9 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
     }
 
     // TODO: attention mask, cross attention mask
-    hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache)
+    hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache, m.tCache)
 
-    outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
+    outputs, err := ctx.FromIntSlice(opts.Outputs(), len(opts.Outputs()))
     if err != nil {
         return nil, err
     }
@@ -4,9 +4,9 @@ import (
     "math"
     "slices"
 
+    "github.com/ollama/ollama/cache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
-    "github.com/ollama/ollama/model"
 )
 
 type TextSelfAttention struct {
@@ -16,7 +16,7 @@ type TextSelfAttention struct {
     Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache cache.Cache, opts *TextModelOptions) ml.Tensor {
     batchSize := hiddenState.Dim(1)
     headDim := opts.hiddenSize / opts.numHeads
 
@@ -31,7 +31,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
     value := sa.Value.Forward(ctx, hiddenState)
     value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-    key, value = cache.Put(ctx, key, value, cache.Options)
+    key, value, mask := cache.Put(ctx, key, value)
 
     query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
     key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
@@ -39,11 +39,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
 
     scores := key.Mulmat(ctx, query)
     scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-
-    if mask != nil {
-        scores = scores.Add(ctx, mask)
-    }
+    scores = scores.Add(ctx, mask)
 
     scores = scores.Softmax(ctx)
 
     attention := value.Mulmat(ctx, scores)
@@ -72,7 +68,7 @@ type TextSelfAttentionDecoderLayer struct {
     MLP *TextMLP
 }
 
-func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache cache.Cache, _ *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
     residual := hiddenState
 
     hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -85,6 +81,10 @@ func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, pos
     return hiddenState.Add(ctx, residual)
 }
 
+func (d *TextSelfAttentionDecoderLayer) Run() bool {
+    return true
+}
+
 type TextCrossAttention struct {
     QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
     Query     *nn.Linear  `gguf:"cross_attn_q_proj"`
@@ -94,23 +94,29 @@ type TextCrossAttention struct {
     Output *nn.Linear `gguf:"cross_attn_o_proj"`
 }
 
-func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, _ cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
     batchSize := hiddenState.Dim(1)
     headDim := opts.hiddenSize / opts.numHeads
-    numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
 
     query := ca.Query.Forward(ctx, hiddenState)
     query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
     query = ca.QueryNorm.Forward(ctx, query, opts.eps)
 
-    key := ca.Key.Forward(ctx, crossAttentionStates)
-    key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
-    key = ca.KeyNorm.Forward(ctx, key, opts.eps)
+    var key, value ml.Tensor
+    if crossAttentionStates != nil {
+        numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
 
-    value := ca.Value.Forward(ctx, crossAttentionStates)
-    value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
+        key = ca.Key.Forward(ctx, crossAttentionStates)
+        key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
+        key = ca.KeyNorm.Forward(ctx, key, opts.eps)
 
-    // TODO cache key, value
+        value = ca.Value.Forward(ctx, crossAttentionStates)
+        value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
+
+        tCache.Put(ctx, key, value)
+    } else {
+        key, value, _ = tCache.Get(ctx)
+    }
 
     query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
     key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
@@ -135,13 +141,17 @@ type TextCrossAttentionDecoderLayer struct {
     MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
     MLP     *TextMLP
     MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
+
+    run bool
 }
 
-func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
+    d.run = true
+
     residual := hiddenState
 
     hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
-    hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
+    hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, tCache, opts)
     hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
     hiddenState = hiddenState.Add(ctx, residual)
     residual = hiddenState
@@ -152,18 +162,23 @@ func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
     return hiddenState.Add(ctx, residual)
 }
 
+func (d *TextCrossAttentionDecoderLayer) Run() bool {
+    return d.run
+}
+
 type TextDecoderLayer interface {
-    Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor
+    Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor
+    Run() bool
 }
 
 type TextDecoder struct {
     Layers []TextDecoderLayer
 }
 
-func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
     for i, layer := range d.Layers {
-        if !slices.Contains(opts.crossAttentionLayers, uint32(i)) || crossAttentionStates != nil {
-            hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), opts)
+        if layer.Run() || crossAttentionStates != nil {
+            hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), tCache.Sub(i), opts)
         }
     }
 
@@ -189,9 +204,9 @@ type TextModel struct {
     *TextModelOptions
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache) ml.Tensor {
     hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
-    hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
+    hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, tCache, m.TextModelOptions)
     hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
     return m.Output.Forward(ctx, hiddenState)
 }
@@ -20,51 +20,28 @@ import (
     _ "github.com/ollama/ollama/ml/backend"
 )
 
-type Cache struct {
-    cache.Cache
-    cache.Options
-}
-
-func (c Cache) Sub(i int) Cache {
-    if c.Cache != nil {
-        return Cache{
-            Cache:   c.Cache.Sub(i),
-            Options: c.Options,
-        }
-    }
-
-    return c
-}
-
-func (c Cache) Put(ctx ml.Context, key, value ml.Tensor, opts cache.Options) (ml.Tensor, ml.Tensor) {
-    if c.Cache != nil {
-        return c.Cache.Put(ctx, key, value, opts)
-    }
-
-    return key, value
-}
-
 type Options struct {
     inputs []int32
+    positions []int32
+    outputs []int32
 
-    Offset int
+    sequences []int
 
     Images []image.Image
 
-    Cache
+    cache.Cache
 }
 
 func (opts Options) Inputs() []int32 {
-    return opts.inputs[opts.Offset:]
+    return opts.inputs
 }
 
 func (opts Options) Positions() []int32 {
-    positions := make([]int32, len(opts.inputs)-opts.Offset)
-    for i := range positions {
-        positions[i] = int32(opts.Offset + i)
-    }
-
-    return positions
+    return opts.positions
+}
+
+func (opts Options) Outputs() []int32 {
+    return opts.outputs
 }
 
 type OptionsFunc func(Model, *Options)
@@ -75,10 +52,21 @@ func WithInputIDs(ids []int32) OptionsFunc {
     }
 }
 
-func WithOffset(offset int) OptionsFunc {
+func WithPositions(pos []int32) OptionsFunc {
     return func(m Model, opts *Options) {
-        opts.Offset = offset
-        opts.Cache.Position = offset
+        opts.positions = pos
+    }
+}
+
+func WithOutputs(outputs []int32) OptionsFunc {
+    return func(m Model, opts *Options) {
+        opts.outputs = outputs
+    }
+}
+
+func WithSequences(seqs []int) OptionsFunc {
+    return func(m Model, opts *Options) {
+        opts.sequences = seqs
     }
 }
 
@@ -90,12 +78,7 @@ func WithImage(img image.Image) OptionsFunc {
 
 func WithCache(c cache.Cache) OptionsFunc {
     return func(m Model, opts *Options) {
-        opts.Cache = Cache{
-            Cache: c,
-            Options: cache.Options{
-                Position: opts.Offset,
-            },
-        }
+        opts.Cache = c
     }
 }
 
@@ -272,18 +255,22 @@ func canNil(t reflect.Type) bool {
         t.Kind() == reflect.Slice
 }
 
-func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
     var opts Options
     for _, optsFunc := range optsFuncs {
         optsFunc(m, &opts)
     }
 
-    ctx := m.Backend().NewContext()
+    err := opts.Cache.StartForward(ctx, opts.sequences)
+    if err != nil {
+        return nil, err
+    }
+
     t, err := m.Forward(ctx, opts)
     if err != nil {
         return nil, err
     }
-    defer ctx.Close()
 
+    ctx.Forward(t)
     return ctx.Compute(t), nil
 }
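Putting the model-package changes together, the caller now owns the compute context and passes positions, output indices, and sequence ids explicitly instead of a single Offset. A minimal sketch of one decode step under those assumptions; the helper and its argument choices are illustrative, not part of the commit.

package sketch

import (
    "github.com/ollama/ollama/cache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model"
)

// decodeStep is a hypothetical caller of the new Forward contract.
func decodeStep(m model.Model, kv cache.Cache, inputs []int32, pos int) (ml.Tensor, error) {
    ctx := m.Backend().NewContext()
    defer ctx.Close()

    positions := make([]int32, len(inputs))
    seqs := make([]int, len(inputs))
    for i := range inputs {
        positions[i] = int32(pos + i)
        seqs[i] = 0 // single sequence in this sketch
    }

    return model.Forward(ctx, m,
        model.WithInputIDs(inputs),
        model.WithPositions(positions),
        model.WithOutputs([]int32{int32(len(inputs) - 1)}), // only sample the last token
        model.WithSequences(seqs),
        model.WithCache(kv),
    )
}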
@@ -1,10 +1,10 @@
-package runner
+package common
 
 import (
     "strings"
 )
 
-func findStop(sequence string, stops []string) (bool, string) {
+func FindStop(sequence string, stops []string) (bool, string) {
     for _, stop := range stops {
         if strings.Contains(sequence, stop) {
             return true, stop
@@ -14,7 +14,7 @@ func findStop(sequence string, stops []string) (bool, string) {
     return false, ""
 }
 
-func containsStopSuffix(sequence string, stops []string) bool {
+func ContainsStopSuffix(sequence string, stops []string) bool {
     for _, stop := range stops {
         for i := 1; i <= len(stop); i++ {
             if strings.HasSuffix(sequence, stop[:i]) {
@@ -29,7 +29,7 @@ func containsStopSuffix(sequence string, stops []string) bool {
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
-func truncateStop(pieces []string, stop string) ([]string, bool) {
+func TruncateStop(pieces []string, stop string) ([]string, bool) {
     joined := strings.Join(pieces, "")
 
     index := strings.Index(joined, stop)
@@ -65,7 +65,7 @@ func truncateStop(pieces []string, stop string) ([]string, bool) {
     return result, tokenTruncated
 }
 
-func incompleteUnicode(token string) bool {
+func IncompleteUnicode(token string) bool {
     incomplete := false
 
     // check if there is incomplete UTF-8 character at the end
@@ -1,4 +1,4 @@
-package runner
+package common
 
 import (
     "reflect"
@@ -52,7 +52,7 @@ func TestTruncateStop(t *testing.T) {
 
     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            result, resultTrunc := truncateStop(tt.pieces, tt.stop)
+            result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
             if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
                 t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
             }
@@ -120,7 +120,7 @@ func TestIncompleteUnicode(t *testing.T) {
 
     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            result := incompleteUnicode(tt.input)
+            result := IncompleteUnicode(tt.input)
             if result != tt.expected {
                 t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected)
             }
248
runner/newrunner/cache.go
Normal file
248
runner/newrunner/cache.go
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
package newrunner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"math"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cache"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InputCache struct {
|
||||||
|
// context window size (per slot)
|
||||||
|
numCtx int
|
||||||
|
|
||||||
|
// individual KV caches
|
||||||
|
slots []InputCacheSlot
|
||||||
|
|
||||||
|
// optimize cache eviction for multiple users
|
||||||
|
multiUserCache bool
|
||||||
|
|
||||||
|
cache cache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInputCache(backend ml.Backend, kvSize int, numSlots int, multiUserCache bool) (*InputCache, error) {
|
||||||
|
if kvSize/numSlots < 1 {
|
||||||
|
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
||||||
|
}
|
||||||
|
|
||||||
|
slots := make([]InputCacheSlot, numSlots)
|
||||||
|
|
||||||
|
for i := range slots {
|
||||||
|
slots[i] = InputCacheSlot{
|
||||||
|
Id: i,
|
||||||
|
Inputs: make([]input, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &InputCache{
|
||||||
|
numCtx: kvSize / numSlots,
|
||||||
|
slots: slots,
|
||||||
|
multiUserCache: multiUserCache,
|
||||||
|
cache: cache.NewCausalCache(backend, kvSize, ml.DTypeF32),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Locking: Operations on InputCacheSlot (including finding one
|
||||||
|
// through LoadCacheSlot) require a lock to be be held that serializes
|
||||||
|
// these operations with each other and llama.Decode
|
||||||
|
|
||||||
|
type InputCacheSlot struct {
|
||||||
|
// Index in the KV cache
|
||||||
|
Id int
|
||||||
|
|
||||||
|
// Inputs that are stored in the KV cache
|
||||||
|
Inputs []input
|
||||||
|
|
||||||
|
// is this cache actively being processed as part of a sequence?
|
||||||
|
InUse bool
|
||||||
|
|
||||||
|
// last time this cache was used (as of start of processing)
|
||||||
|
lastUsed time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
|
||||||
|
var slot *InputCacheSlot
|
||||||
|
var numPast int
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// In single-user scenarios, the longest cache slot works fine for getting good input
|
||||||
|
// cache hit rates and it reuses the same VRAM over and over again, which is good for
|
||||||
|
// GPU performance in situations where we miss the input cache.
|
||||||
|
// For multiple users, the "best" cache slot produces better input cache hit rates
|
||||||
|
// at the cost of worse performance when we miss the input cache (because it causes
|
||||||
|
// GPU L2 cache misses due to spreading out accesses across VRAM).
|
||||||
|
if !c.multiUserCache {
|
||||||
|
slot, numPast, err = c.findLongestCacheSlot(prompt)
|
||||||
|
} else {
|
||||||
|
slot, numPast, err = c.findBestCacheSlot(prompt)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cachePrompt {
|
||||||
|
numPast = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
slot.InUse = true
|
||||||
|
slot.lastUsed = time.Now()
|
||||||
|
|
||||||
|
if numPast == len(prompt) {
|
||||||
|
// Leave one input to sample so we can get a response
|
||||||
|
numPast--
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.cache.Remove(slot.Id, numPast, math.MaxInt)
|
||||||
|
if err != nil {
|
||||||
|
// Some models don't support partial erasure
|
||||||
|
c.cache.Remove(slot.Id, 0, math.MaxInt)
|
||||||
|
numPast = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
|
||||||
|
"used", numPast, "remaining", len(prompt)-numPast)
|
||||||
|
|
||||||
|
prompt = prompt[numPast:]
|
||||||
|
slot.Inputs = slot.Inputs[:numPast]
|
||||||
|
|
||||||
|
return slot, prompt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
	longest := -1
	var longestSlot *InputCacheSlot

	for i, s := range c.slots {
		if s.InUse {
			continue
		}

		count := countCommonPrefix(s.Inputs, prompt)
		if count > longest {
			longest = count
			longestSlot = &c.slots[i]
		}
	}

	if longestSlot == nil {
		return nil, 0, errors.New("no available cache slots")
	}

	return longestSlot, longest, nil
}

func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
	oldest := time.Now()
	var oldestSlot *InputCacheSlot

	longest := -1
	var longestSlot *InputCacheSlot

	for i, s := range c.slots {
		count := countCommonPrefix(s.Inputs, prompt)
		if count > longest {
			longest = count
			longestSlot = &c.slots[i]
		}

		if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
			oldest = s.lastUsed
			oldestSlot = &c.slots[i]
		}
	}

	if longest == len(longestSlot.Inputs) && !longestSlot.InUse {
		return longestSlot, longest, nil
	}

	if oldestSlot.InUse {
		return nil, 0, errors.New("no available cache slots")
	}

	if len(oldestSlot.Inputs) != 0 {
		slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
			"used", oldestSlot.lastUsed)
	}

	if longest > 0 && longestSlot != oldestSlot {
		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
			len(longestSlot.Inputs))
		oldestSlot.Inputs = make([]input, longest)
		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
		// This is only nil for unit tests
		if c.cache != nil {
			c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
		}
	}

	return oldestSlot, longest, nil
}

func countCommonPrefix(a []input, b []input) int {
	var count int

	for i := range a {
		if i >= len(b) {
			break
		}

		if !reflect.DeepEqual(a[i], b[i]) {
			break
		}

		count++
	}

	return count
}

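countCommonPrefix is what both slot-selection strategies above rely on: comparing cached inputs {1, 2, 3} against a prompt {1, 2, 4} yields 2, for example, and because elements are compared with reflect.DeepEqual, an image input only counts as a match when the whole struct is identical.
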
func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
	targetFree := (c.numCtx - numKeep) / 2
	targetFree = max(targetFree, 1)

	currentFree := c.numCtx - inputLen
	discard := targetFree - currentFree

	if discard < 0 {
		discard = 0
	}

	return discard
}

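As a worked example (mirroring the "Shift" and "No Op" cases in the oldrunner cache tests later in this diff): with numCtx=2048 and numKeep=5, an inputLen of 2048 gives targetFree=(2048-5)/2=1021 and currentFree=0, so 1021 inputs are discarded, while an inputLen of 512 leaves currentFree=1536, which already exceeds targetFree, so nothing is discarded.
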
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
	if numKeep >= c.numCtx {
		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
	}

	discard := c.ShiftDiscard(len(slot.Inputs), numKeep)

	if discard <= 0 {
		return nil
	}

	slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
		"keep", numKeep, "discard", discard)

	// TODO (jessegross): KV cache removal can fail for certain types of models
	err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
	if err != nil {
		return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
	}

	for i := numKeep + discard; i < len(slot.Inputs); i++ {
		slot.Inputs[i-discard] = slot.Inputs[i]
	}
	slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]

	return nil
}

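To make the intended call pattern concrete, here is a minimal sketch of how a caller drives a slot through these methods; the runSlot name and its prompt and numKeep arguments are illustrative and not part of this change:

// Illustrative sketch only: reuse a cached prefix, shift when the slot's
// context fills up, and release the slot once the sequence finishes.
func runSlot(c *InputCache, prompt []input, numKeep int) error {
	slot, remaining, err := c.LoadCacheSlot(prompt, true)
	if err != nil {
		return err
	}
	defer func() { slot.InUse = false }()

	for len(remaining) > 0 {
		if len(slot.Inputs)+1 > c.numCtx {
			// Context is full: discard the oldest half of history, keeping numKeep inputs.
			if err := c.ShiftCacheSlot(slot, numKeep); err != nil {
				return err
			}
		}
		// A real caller submits these inputs to the model in batches before
		// recording them; here they are simply appended one at a time.
		slot.Inputs = append(slot.Inputs, remaining[0])
		remaining = remaining[1:]
	}

	return nil
}
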
@ -1,4 +1,4 @@
|
|||||||
package runner
|
package newrunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
971
runner/newrunner/runner.go
Normal file
@ -0,0 +1,971 @@
|
|||||||
|
package newrunner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/runner/common"
|
||||||
|
"github.com/ollama/ollama/sample"
|
||||||
|
|
||||||
|
_ "github.com/ollama/ollama/model/llama"
|
||||||
|
_ "github.com/ollama/ollama/model/mllama"
|
||||||
|
)
|
||||||
|
|
||||||
|
// input is an element of the prompt to process, either
|
||||||
|
// a token or an image embedding (generated from a vision projector)
|
||||||
|
type input struct {
|
||||||
|
token int32
|
||||||
|
|
||||||
|
// embed is an image embedding
|
||||||
|
//embed []float32
|
||||||
|
|
||||||
|
image image.Image
|
||||||
|
}
|
||||||
|
|
||||||
|
type Sequence struct {
|
||||||
|
// batch index
|
||||||
|
iBatch int
|
||||||
|
|
||||||
|
// prompt inputs left to evaluate
|
||||||
|
inputs []input
|
||||||
|
|
||||||
|
// inputs that have been added to a batch but not yet submitted to Decode
|
||||||
|
pendingInputs []input
|
||||||
|
|
||||||
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||||
|
pendingResponses []string
|
||||||
|
|
||||||
|
// input cache being used by this sequence
|
||||||
|
cache *InputCacheSlot
|
||||||
|
|
||||||
|
// channel to send responses over
|
||||||
|
responses chan string
|
||||||
|
|
||||||
|
// channel to stop decoding (such as if the remote connection is closed)
|
||||||
|
quit chan bool
|
||||||
|
|
||||||
|
// number of tokens to predict
|
||||||
|
numPredict int
|
||||||
|
|
||||||
|
// set of samplers to run on generated logits
|
||||||
|
samplers []sample.Sampler
|
||||||
|
|
||||||
|
// channel to send back the embedding if embedding only
|
||||||
|
embedding chan []float32
|
||||||
|
|
||||||
|
// stop sequences
|
||||||
|
stop []string
|
||||||
|
|
||||||
|
// number of inputs to keep at the beginning when shifting context window
|
||||||
|
numKeep int
|
||||||
|
|
||||||
|
// true if an embedding is to be returned instead of text generation
|
||||||
|
embeddingOnly bool
|
||||||
|
|
||||||
|
doneReason string
|
||||||
|
|
||||||
|
// Metrics
|
||||||
|
startProcessingTime time.Time
|
||||||
|
startGenerationTime time.Time
|
||||||
|
numPredicted int
|
||||||
|
numPromptInputs int
|
||||||
|
}
|
||||||
|
|
||||||
|
type NewSequenceParams struct {
|
||||||
|
numPredict int
|
||||||
|
stop []string
|
||||||
|
numKeep int
|
||||||
|
samplers []sample.Sampler
|
||||||
|
embedding bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||||
|
s.ready.Wait()
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
inputs, err := s.inputs(prompt, images)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||||
|
} else if len(inputs) == 0 {
|
||||||
|
return nil, errors.New("no input provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.numKeep < 0 {
|
||||||
|
params.numKeep = len(inputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that at least 1 input can be discarded during shift
|
||||||
|
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||||
|
|
||||||
|
if len(inputs) > s.cache.numCtx {
|
||||||
|
discard := len(inputs) - s.cache.numCtx
|
||||||
|
newInputs := inputs[:params.numKeep]
|
||||||
|
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||||
|
|
||||||
|
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
|
||||||
|
inputs = newInputs
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(jessegross): Ingest cached history for grammar
|
||||||
|
|
||||||
|
return &Sequence{
|
||||||
|
inputs: inputs,
|
||||||
|
numPromptInputs: len(inputs),
|
||||||
|
startProcessingTime: startTime,
|
||||||
|
numPredict: params.numPredict,
|
||||||
|
pendingResponses: make([]string, 0),
|
||||||
|
responses: make(chan string, 100),
|
||||||
|
quit: make(chan bool, 1),
|
||||||
|
embedding: make(chan []float32, 1),
|
||||||
|
samplers: params.samplers,
|
||||||
|
embeddingOnly: params.embedding,
|
||||||
|
stop: params.stop,
|
||||||
|
numKeep: params.numKeep,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// inputs processes the prompt and images into a list of inputs
|
||||||
|
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||||
|
// generating image embeddings for each image
|
||||||
|
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
||||||
|
var inputs []input
|
||||||
|
var parts []string
|
||||||
|
var matches [][]string
|
||||||
|
|
||||||
|
//if s.image != nil {
|
||||||
|
re := regexp.MustCompile(`\[img-(\d+)\]`)
|
||||||
|
parts = re.Split(prompt, -1)
|
||||||
|
matches = re.FindAllStringSubmatch(prompt, -1)
|
||||||
|
/*} else {
|
||||||
|
parts = []string{prompt}
|
||||||
|
}*/
|
||||||
|
|
||||||
|
for i, part := range parts {
|
||||||
|
// text - tokenize
|
||||||
|
tokens, err := s.model.(model.TextProcessor).Encode(part)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range tokens {
|
||||||
|
inputs = append(inputs, input{token: t})
|
||||||
|
}
|
||||||
|
|
||||||
|
// image - generate image embedding
|
||||||
|
if i < len(matches) {
|
||||||
|
n, _ := strconv.Atoi(matches[i][1])
|
||||||
|
|
||||||
|
imageIndex := -1
|
||||||
|
for j := range images {
|
||||||
|
if images[j].ID == n {
|
||||||
|
imageIndex = j
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if imageIndex < 0 {
|
||||||
|
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs = append(inputs, input{image: image})
|
||||||
|
|
||||||
|
/*embed, err := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range embed {
|
||||||
|
inputs = append(inputs, input{embed: e})
|
||||||
|
}*/
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputs, nil
|
||||||
|
}
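For example, a prompt of "describe [img-0] in detail" together with one ImageData entry whose ID is 0 produces the tokens for "describe ", a single image input decoded from that entry's bytes, and then the tokens for " in detail"; an [img-N] tag whose N matches no provided image ID fails with the invalid image index error above.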
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
// is the server ready to process requests?
|
||||||
|
// protects access to model and image
|
||||||
|
ready sync.WaitGroup
|
||||||
|
|
||||||
|
// loaded model
|
||||||
|
model model.Model
|
||||||
|
|
||||||
|
// status for external health reporting - loading, ready to serve, etc.
|
||||||
|
status ServerStatus
|
||||||
|
|
||||||
|
// current progress on loading the model
|
||||||
|
progress float32
|
||||||
|
|
||||||
|
// number of simultaneous requests to handle
|
||||||
|
parallel int
|
||||||
|
|
||||||
|
// maximum number of elements in a batch (per sequence)
|
||||||
|
// TODO (jmorganca): make this n_batch
|
||||||
|
batchSize int
|
||||||
|
|
||||||
|
// protects access to everything below this line
|
||||||
|
// this is context state needed for decoding
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// indicates that data is ready for processing
|
||||||
|
cond *sync.Cond
|
||||||
|
|
||||||
|
// the list of simultaneous sequences being evaluated
|
||||||
|
seqs []*Sequence
|
||||||
|
|
||||||
|
// seqs can have a maximum of parallel entries, which
|
||||||
|
// is enforced by seqsSem
|
||||||
|
seqsSem *semaphore.Weighted
|
||||||
|
|
||||||
|
// KV cache
|
||||||
|
cache *InputCache
|
||||||
|
|
||||||
|
// next sequence for prompt processing to avoid starvation
|
||||||
|
// TODO(jessegross): Currently unused
|
||||||
|
nextSeq int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) allNil() bool {
|
||||||
|
for _, item := range s.seqs {
|
||||||
|
if item != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func flushPending(seq *Sequence) bool {
|
||||||
|
joined := strings.Join(seq.pendingResponses, "")
|
||||||
|
seq.pendingResponses = []string{}
|
||||||
|
|
||||||
|
// Check if there are any partial UTF-8 characters remaining.
|
||||||
|
// We already check and queue as we are generating but some may
|
||||||
|
// still make it here:
|
||||||
|
// - Sequence is ending, e.g. generation limit has been hit
|
||||||
|
// - Invalid characters in the middle of a string
|
||||||
|
// This is a stricter check to ensure we never output invalid Unicode.
|
||||||
|
for !utf8.ValidString(joined) {
|
||||||
|
joined = joined[:len(joined)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(joined) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case seq.responses <- joined:
|
||||||
|
return true
|
||||||
|
case <-seq.quit:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||||
|
seq := s.seqs[seqIndex]
|
||||||
|
|
||||||
|
flushPending(seq)
|
||||||
|
seq.doneReason = reason
|
||||||
|
close(seq.responses)
|
||||||
|
close(seq.embedding)
|
||||||
|
seq.cache.InUse = false
|
||||||
|
s.seqs[seqIndex] = nil
|
||||||
|
s.seqsSem.Release(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) run(ctx context.Context) {
|
||||||
|
s.ready.Wait()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
err := s.processBatch()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) processBatch() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
for s.allNil() {
|
||||||
|
s.cond.Wait() // Wait until an item is added
|
||||||
|
}
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var inputIDs []int32
|
||||||
|
var pos []int32
|
||||||
|
var outputs []int32
|
||||||
|
var seqs []int
|
||||||
|
|
||||||
|
var image image.Image
|
||||||
|
|
||||||
|
for i, seq := range s.seqs {
|
||||||
|
if seq == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// if past the num predict limit
|
||||||
|
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||||
|
s.removeSequence(i, "limit")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for j, input := range seq.inputs {
|
||||||
|
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
|
||||||
|
if len(seq.pendingInputs) == 0 {
|
||||||
|
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if j >= s.batchSize {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.image != nil {
|
||||||
|
if image != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
image = input.image
|
||||||
|
seq.pendingInputs = append(seq.pendingInputs, input)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
inputIDs = append(inputIDs, input.token)
|
||||||
|
pos = append(pos, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||||
|
seqs = append(seqs, seq.cache.Id)
|
||||||
|
|
||||||
|
seq.iBatch = len(outputs)
|
||||||
|
if j+1 == len(seq.inputs) {
|
||||||
|
outputs = append(outputs, int32(len(inputIDs)-1))
|
||||||
|
}
|
||||||
|
seq.pendingInputs = append(seq.pendingInputs, input)
|
||||||
|
}
|
||||||
|
|
||||||
|
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inputIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var options []model.OptionsFunc
|
||||||
|
if image != nil {
|
||||||
|
options = append(options, model.WithImage(image))
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := s.model.Backend().NewContext()
|
||||||
|
defer ctx.Close()
|
||||||
|
|
||||||
|
logit, err := model.Forward(ctx, s.model, append(options, model.WithCache(s.cache.cache), model.WithInputIDs(inputIDs), model.WithPositions(pos), model.WithOutputs(outputs), model.WithSequences(seqs))...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := logit.Floats()
|
||||||
|
|
||||||
|
for i, seq := range s.seqs {
|
||||||
|
if seq == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// After calling Forward, pending inputs are now in the cache
|
||||||
|
if len(seq.pendingInputs) > 0 {
|
||||||
|
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||||
|
seq.pendingInputs = []input{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't sample prompt processing
|
||||||
|
if len(seq.inputs) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
seq.numPredicted++
|
||||||
|
if seq.numPredicted == 1 {
|
||||||
|
seq.startGenerationTime = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// if done processing the prompt, generate an embedding and return
|
||||||
|
if seq.embeddingOnly {
|
||||||
|
/*embed := s.lc.GetEmbeddingsSeq(seq.cache.Id)
|
||||||
|
if embed == nil {
|
||||||
|
embed = s.lc.GetEmbeddingsIth(seq.iBatch)
|
||||||
|
}
|
||||||
|
|
||||||
|
seq.embedding <- embed*/
|
||||||
|
s.removeSequence(i, "")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
vocabSize := len(f32s) / len(outputs)
|
||||||
|
seqLogits := f32s[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]
|
||||||
|
|
||||||
|
// TODO(jessegross): The data type and number of outputs for the samplers seem inconsistent
|
||||||
|
f64s := make([]float64, vocabSize)
|
||||||
|
for j, f32 := range seqLogits {
|
||||||
|
f64s[j] = float64(f32)
|
||||||
|
}
|
||||||
|
|
||||||
|
// do sampling
|
||||||
|
f64s, err = sample.Sample(f64s, seq.samplers...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputIDs []int32
|
||||||
|
for _, f64 := range f64s {
|
||||||
|
if !s.model.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) {
|
||||||
|
outputIDs = append(outputIDs, int32(f64))
|
||||||
|
} else {
|
||||||
|
s.removeSequence(i, "stop")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(outputIDs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
piece, err := s.model.(model.TextProcessor).Decode(outputIDs)
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
continue
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range outputIDs {
|
||||||
|
seq.inputs = append(seq.inputs, input{token: id})
|
||||||
|
}
|
||||||
|
|
||||||
|
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||||
|
sequence := strings.Join(seq.pendingResponses, "")
|
||||||
|
|
||||||
|
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||||
|
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||||
|
|
||||||
|
var tokenTruncated bool
|
||||||
|
origLen := len(seq.pendingResponses)
|
||||||
|
seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
|
||||||
|
newLen := len(seq.pendingResponses)
|
||||||
|
|
||||||
|
// Update the cache based on the tokens that will be returned:
|
||||||
|
// - We have more tokens than are currently in the cache because
|
||||||
|
// the last ones generated weren't submitted to Forward
|
||||||
|
// - Remove any stop sequences that we stripped out
|
||||||
|
// - If truncateStop removed a portion of a token, drop that
|
||||||
|
// - As defense-in-depth, if truncateStop didn't find a stop token
|
||||||
|
// remove the extra ones that we added to the cache len
|
||||||
|
tokenLen := len(seq.cache.Inputs) + len(outputIDs)
|
||||||
|
tokenLen -= origLen - newLen
|
||||||
|
if tokenTruncated {
|
||||||
|
tokenLen--
|
||||||
|
}
|
||||||
|
if origLen == newLen {
|
||||||
|
tokenLen = len(seq.cache.Inputs)
|
||||||
|
}
|
||||||
|
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
||||||
|
|
||||||
|
s.removeSequence(i, "stop")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.ContainsStopSuffix(sequence, seq.stop) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.IncompleteUnicode(sequence) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !flushPending(seq) {
|
||||||
|
s.removeSequence(i, "connection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
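In short, each pass of processBatch gathers up to batchSize pending inputs per sequence (shifting a sequence's cache slot if it would overflow numCtx), runs one model.Forward over the combined batch, and then, for sequences whose prompt is fully processed, converts the selected logits to float64, samples a token, checks it against EOS and the stop strings, and either queues the decoded piece on seq.responses or removes the sequence.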
|
||||||
|
|
||||||
|
// TODO (jmorganca): use structs from the api package to avoid duplication
|
||||||
|
// this way the api acts as a proxy instead of using a different api for the
|
||||||
|
// runner
|
||||||
|
type Options struct {
|
||||||
|
api.Runner
|
||||||
|
|
||||||
|
NumKeep int `json:"n_keep"`
|
||||||
|
Seed int `json:"seed"`
|
||||||
|
NumPredict int `json:"n_predict"`
|
||||||
|
TopK int `json:"top_k"`
|
||||||
|
TopP float32 `json:"top_p"`
|
||||||
|
MinP float32 `json:"min_p"`
|
||||||
|
TypicalP float32 `json:"typical_p"`
|
||||||
|
RepeatLastN int `json:"repeat_last_n"`
|
||||||
|
Temperature float32 `json:"temperature"`
|
||||||
|
RepeatPenalty float32 `json:"repeat_penalty"`
|
||||||
|
PresencePenalty float32 `json:"presence_penalty"`
|
||||||
|
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||||
|
Mirostat int `json:"mirostat"`
|
||||||
|
MirostatTau float32 `json:"mirostat_tau"`
|
||||||
|
MirostatEta float32 `json:"mirostat_eta"`
|
||||||
|
Stop []string `json:"stop"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageData struct {
|
||||||
|
Data []byte `json:"data"`
|
||||||
|
ID int `json:"id"`
|
||||||
|
AspectRatioID int `json:"aspect_ratio_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionRequest struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Images []ImageData `json:"image_data"`
|
||||||
|
Grammar string `json:"grammar"`
|
||||||
|
CachePrompt bool `json:"cache_prompt"`
|
||||||
|
|
||||||
|
Options
|
||||||
|
}
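Going by the JSON tags above (and the embedded Options), a minimal request body for the /completion endpoint registered below might look like the following; the field values are only illustrative:

{"prompt": "Why is the sky blue?", "cache_prompt": true, "n_predict": 128, "stop": ["\n\n"]}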
|
||||||
|
|
||||||
|
type Timings struct {
|
||||||
|
PredictedN int `json:"predicted_n"`
|
||||||
|
PredictedMS float64 `json:"predicted_ms"`
|
||||||
|
PromptN int `json:"prompt_n"`
|
||||||
|
PromptMS float64 `json:"prompt_ms"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionResponse struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Stop bool `json:"stop"`
|
||||||
|
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
||||||
|
PredictedN int `json:"predicted_n,omitempty"`
|
||||||
|
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
||||||
|
PromptN int `json:"prompt_n,omitempty"`
|
||||||
|
PromptMS float64 `json:"prompt_ms,omitempty"`
|
||||||
|
|
||||||
|
Timings Timings `json:"timings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSamplers(req CompletionRequest) []sample.Sampler {
|
||||||
|
/*var samplingParams llama.SamplingParams
|
||||||
|
samplingParams.TopK = req.TopK
|
||||||
|
samplingParams.TopP = req.TopP
|
||||||
|
samplingParams.MinP = req.MinP
|
||||||
|
samplingParams.TypicalP = req.TypicalP
|
||||||
|
samplingParams.Temp = req.Temperature
|
||||||
|
samplingParams.RepeatLastN = req.RepeatLastN
|
||||||
|
samplingParams.PenaltyRepeat = req.RepeatPenalty
|
||||||
|
samplingParams.PenaltyFreq = req.FrequencyPenalty
|
||||||
|
samplingParams.PenaltyPresent = req.PresencePenalty
|
||||||
|
samplingParams.Mirostat = req.Mirostat
|
||||||
|
samplingParams.MirostatTau = req.MirostatTau
|
||||||
|
samplingParams.MirostatEta = req.MirostatEta
|
||||||
|
samplingParams.Seed = uint32(req.Seed)
|
||||||
|
samplingParams.Grammar = req.Grammar*/
|
||||||
|
|
||||||
|
return []sample.Sampler{sample.Greedy()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req CompletionRequest
|
||||||
|
req.Options = Options(api.DefaultOptions())
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the headers to indicate streaming
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
|
numPredict: req.NumPredict,
|
||||||
|
stop: req.Stop,
|
||||||
|
numKeep: req.NumKeep,
|
||||||
|
samplers: getSamplers(req),
|
||||||
|
embedding: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure there is a place to put the sequence, released when removed from s.seqs
|
||||||
|
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
slog.Info("aborting completion request due to client closing the connection")
|
||||||
|
} else {
|
||||||
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
found := false
|
||||||
|
for i, sq := range s.seqs {
|
||||||
|
if sq == nil {
|
||||||
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||||
|
if err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.seqs[i] = seq
|
||||||
|
s.cond.Signal()
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
close(seq.quit)
|
||||||
|
return
|
||||||
|
case content, ok := <-seq.responses:
|
||||||
|
if ok {
|
||||||
|
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||||
|
Content: content,
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
|
close(seq.quit)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
flusher.Flush()
|
||||||
|
} else {
|
||||||
|
// Send the final response
|
||||||
|
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||||
|
Stop: true,
|
||||||
|
StoppedLimit: seq.doneReason == "limit",
|
||||||
|
Timings: Timings{
|
||||||
|
PromptN: seq.numPromptInputs,
|
||||||
|
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
||||||
|
PredictedN: seq.numPredicted,
|
||||||
|
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingRequest struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
CachePrompt bool `json:"cache_prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
Embedding []float32 `json:"embedding"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req EmbeddingRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
slog.Debug("embedding request", "content", req.Content)
|
||||||
|
|
||||||
|
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure there is a place to put the sequence, released when removed from s.seqs
|
||||||
|
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
slog.Info("aborting embeddings request due to client closing the connection")
|
||||||
|
} else {
|
||||||
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
found := false
|
||||||
|
for i, sq := range s.seqs {
|
||||||
|
if sq == nil {
|
||||||
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||||
|
if err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.seqs[i] = seq
|
||||||
|
s.cond.Signal()
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding := <-seq.embedding
|
||||||
|
|
||||||
|
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
||||||
|
Embedding: embedding,
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type HealthResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Progress float32 `json:"progress"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStatus int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ServerStatusReady ServerStatus = iota
|
||||||
|
ServerStatusLoadingModel
|
||||||
|
ServerStatusError
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s ServerStatus) ToString() string {
|
||||||
|
switch s {
|
||||||
|
case ServerStatusReady:
|
||||||
|
return "ok"
|
||||||
|
case ServerStatusLoadingModel:
|
||||||
|
return "loading model"
|
||||||
|
default:
|
||||||
|
return "server error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
||||||
|
Status: s.status.ToString(),
|
||||||
|
Progress: s.progress,
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type multiLPath []string
|
||||||
|
|
||||||
|
func (m *multiLPath) Set(value string) error {
|
||||||
|
*m = append(*m, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *multiLPath) String() string {
|
||||||
|
return strings.Join(*m, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) loadModel(
|
||||||
|
//params llama.ModelParams,
|
||||||
|
mpath string,
|
||||||
|
//lpath multiLPath,
|
||||||
|
kvSize int,
|
||||||
|
/*kvCacheType string,
|
||||||
|
flashAttention bool,*/
|
||||||
|
_ int,
|
||||||
|
multiUserCache bool,
|
||||||
|
) {
|
||||||
|
var err error
|
||||||
|
s.model, err = model.New(mpath)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
|
||||||
|
s.lc, err = llama.NewContextWithModel(s.oldModel, ctxParams)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if lpath.String() != "" {
|
||||||
|
for _, path := range lpath {
|
||||||
|
err := s.oldModel.ApplyLoraFromFile(s.lc, path, 1.0, threads)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}*/
|
||||||
|
|
||||||
|
s.cache, err = NewInputCache(s.model.Backend(), kvSize, s.parallel, multiUserCache)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.status = ServerStatusReady
|
||||||
|
s.ready.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Execute(args []string) error {
|
||||||
|
fs := flag.NewFlagSet("runner", flag.ExitOnError)
|
||||||
|
mpath := fs.String("model", "", "Path to model binary file")
|
||||||
|
parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
|
||||||
|
batchSize := fs.Int("batch-size", 512, "Batch size")
|
||||||
|
_ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
||||||
|
_ = fs.Int("main-gpu", 0, "Main GPU")
|
||||||
|
_ = fs.Bool("flash-attn", false, "Enable flash attention")
|
||||||
|
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||||
|
_ = fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||||
|
port := fs.Int("port", 8080, "Port to expose the server on")
|
||||||
|
threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
||||||
|
verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
|
||||||
|
_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
|
||||||
|
_ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
|
||||||
|
tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
|
||||||
|
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
|
||||||
|
|
||||||
|
var lpaths multiLPath
|
||||||
|
fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
|
||||||
|
|
||||||
|
fs.Usage = func() {
|
||||||
|
fmt.Fprintf(fs.Output(), "Runner usage\n")
|
||||||
|
fs.PrintDefaults()
|
||||||
|
}
|
||||||
|
if err := fs.Parse(args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
level := slog.LevelInfo
|
||||||
|
if *verbose {
|
||||||
|
level = slog.LevelDebug
|
||||||
|
}
|
||||||
|
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||||
|
Level: level,
|
||||||
|
AddSource: true,
|
||||||
|
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
|
||||||
|
if attr.Key == slog.SourceKey {
|
||||||
|
source := attr.Value.Any().(*slog.Source)
|
||||||
|
source.File = filepath.Base(source.File)
|
||||||
|
}
|
||||||
|
return attr
|
||||||
|
},
|
||||||
|
})
|
||||||
|
slog.SetDefault(slog.New(handler))
|
||||||
|
slog.Info("starting ollama engine")
|
||||||
|
//slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads)
|
||||||
|
|
||||||
|
server := &Server{
|
||||||
|
batchSize: *batchSize,
|
||||||
|
parallel: *parallel,
|
||||||
|
seqs: make([]*Sequence, *parallel),
|
||||||
|
seqsSem: semaphore.NewWeighted(int64(*parallel)),
|
||||||
|
status: ServerStatusLoadingModel,
|
||||||
|
}
|
||||||
|
|
||||||
|
var tensorSplitFloats []float32
|
||||||
|
if *tensorSplit != "" {
|
||||||
|
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
|
||||||
|
|
||||||
|
tensorSplitFloats = make([]float32, 0, len(stringFloats))
|
||||||
|
for _, s := range stringFloats {
|
||||||
|
f, _ := strconv.ParseFloat(s, 32)
|
||||||
|
tensorSplitFloats = append(tensorSplitFloats, float32(f))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*params := llama.ModelParams{
|
||||||
|
NumGpuLayers: *nGpuLayers,
|
||||||
|
MainGpu: *mainGpu,
|
||||||
|
UseMmap: !*noMmap && lpaths.String() == "",
|
||||||
|
UseMlock: *mlock,
|
||||||
|
TensorSplit: tensorSplitFloats,
|
||||||
|
Progress: func(progress float32) {
|
||||||
|
server.progress = progress
|
||||||
|
},
|
||||||
|
}*/
|
||||||
|
|
||||||
|
server.ready.Add(1)
|
||||||
|
go server.loadModel(*mpath, *kvSize, *threads, *multiUserCache)
|
||||||
|
|
||||||
|
server.cond = sync.NewCond(&server.mu)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go server.run(ctx)
|
||||||
|
|
||||||
|
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
||||||
|
listener, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Listen error:", err)
|
||||||
|
cancel()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/embedding", server.embeddings)
|
||||||
|
mux.HandleFunc("/completion", server.completion)
|
||||||
|
mux.HandleFunc("/health", server.health)
|
||||||
|
|
||||||
|
httpServer := http.Server{
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("Server listening on", addr)
|
||||||
|
if err := httpServer.Serve(listener); err != nil {
|
||||||
|
log.Fatal("server error:", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package runner
|
package oldrunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
292
runner/oldrunner/cache_test.go
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
package oldrunner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCountCommon(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
t1 []input
|
||||||
|
t2 []input
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Equal",
|
||||||
|
t1: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||||
|
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||||
|
expected: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Prefix",
|
||||||
|
t1: []input{{token: 1}},
|
||||||
|
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||||
|
expected: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Embeddings Prefix",
|
||||||
|
t1: []input{{embed: []float32{0.1, 0.2, 0.3}}},
|
||||||
|
t2: []input{{embed: []float32{0.1, 0.2, 0.3}}, {embed: []float32{0.4, 0.5, 0.6}}, {embed: []float32{0.7}}},
|
||||||
|
expected: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Embeddings Prefix Partial",
|
||||||
|
t1: []input{{embed: []float32{0.1, 0.2, 0.3}}},
|
||||||
|
t2: []input{{embed: []float32{0.1, 0.2}}, {embed: []float32{0.4, 0.5, 0.6}}, {embed: []float32{0.7}}},
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mixed",
|
||||||
|
t1: []input{{token: 1}, {embed: []float32{0.2, 0.3, 0.4}}},
|
||||||
|
t2: []input{{token: 1}, {embed: []float32{0.2, 0.3, 0.4}}, {token: 5}},
|
||||||
|
expected: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty",
|
||||||
|
t1: []input{},
|
||||||
|
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Both Empty",
|
||||||
|
t1: []input{},
|
||||||
|
t2: []input{},
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := countCommonPrefix(tt.t1, tt.t2)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindCacheSlot(t *testing.T) {
|
||||||
|
type expected struct {
|
||||||
|
result int
|
||||||
|
len int
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cache InputCache
|
||||||
|
prompt []input
|
||||||
|
longest expected
|
||||||
|
best expected
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty",
|
||||||
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
|
{
|
||||||
|
Id: 0,
|
||||||
|
Inputs: []input{},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Time{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Inputs: []input{},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Time{},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
prompt: []input{{token: 1}},
|
||||||
|
longest: expected{result: 0, len: 0},
|
||||||
|
best: expected{result: 0, len: 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extend",
|
||||||
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
|
{
|
||||||
|
Id: 0,
|
||||||
|
Inputs: []input{{token: 1}},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Inputs: []input{{token: 1}, {token: 2}},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
prompt: []input{{token: 1}, {token: 2}},
|
||||||
|
longest: expected{result: 1, len: 2},
|
||||||
|
best: expected{result: 1, len: 2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "New",
|
||||||
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
|
{
|
||||||
|
Id: 0,
|
||||||
|
Inputs: []input{{token: 1}, {token: 2}},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Inputs: []input{},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Time{},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
prompt: []input{{token: 2}},
|
||||||
|
longest: expected{result: 0, len: 0},
|
||||||
|
best: expected{result: 1, len: 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Fork",
|
||||||
|
cache: InputCache{
|
||||||
|
slots: []InputCacheSlot{
|
||||||
|
{
|
||||||
|
Id: 0,
|
||||||
|
Inputs: []input{{token: 1}, {token: 2}},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Inputs: []input{},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Time{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
prompt: []input{{token: 1}},
|
||||||
|
longest: expected{result: 0, len: 1},
|
||||||
|
best: expected{result: 1, len: 1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Evict",
|
||||||
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
|
{
|
||||||
|
Id: 0,
|
||||||
|
Inputs: []input{{token: 1}},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Inputs: []input{{token: 1}, {token: 2}},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
prompt: []input{{token: 2}, {token: 3}},
|
||||||
|
longest: expected{result: 0, len: 0},
|
||||||
|
best: expected{result: 1, len: 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "In use",
|
||||||
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
|
{
|
||||||
|
Id: 0,
|
||||||
|
Inputs: []input{{token: 1}, {token: 2}},
|
||||||
|
InUse: true,
|
||||||
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Inputs: []input{{token: 1}},
|
||||||
|
InUse: false,
|
||||||
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
prompt: []input{{token: 1}, {token: 2}},
|
||||||
|
longest: expected{result: 1, len: 1},
|
||||||
|
best: expected{result: 1, len: 2},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run("Longest-"+tt.name, func(t *testing.T) {
|
||||||
|
result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("findLongestCacheSlot: err %v", err)
|
||||||
|
} else if result.Id != tt.longest.result || resultLen != tt.longest.len {
|
||||||
|
t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v",
|
||||||
|
result.Id, tt.longest.result, resultLen, tt.longest.len)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run("Best-"+tt.name, func(t *testing.T) {
|
||||||
|
result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("findBestCacheSlot: err %v", err)
|
||||||
|
} else if result.Id != tt.best.result || resultLen != tt.best.len {
|
||||||
|
t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v",
|
||||||
|
result.Id, tt.best.result, resultLen, tt.best.len)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShiftDiscard(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
numCtx int
|
||||||
|
numKeep int
|
||||||
|
inputLen int
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Shift",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 5,
|
||||||
|
inputLen: 2048,
|
||||||
|
expected: 1021,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Max Keep",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 2047,
|
||||||
|
inputLen: 2048,
|
||||||
|
expected: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No Keep",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 0,
|
||||||
|
inputLen: 2048,
|
||||||
|
expected: 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Truncate",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 5,
|
||||||
|
inputLen: 5000,
|
||||||
|
expected: 3973,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Truncate Keep",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 2047,
|
||||||
|
inputLen: 5000,
|
||||||
|
expected: 2953,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No Op",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 5,
|
||||||
|
inputLen: 512,
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := InputCache{numCtx: tt.numCtx}
|
||||||
|
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package runner
|
package oldrunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
@ -1,4 +1,4 @@
|
|||||||
package runner
|
package oldrunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
@ -1,4 +1,4 @@
|
|||||||
package runner
|
package oldrunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -24,6 +24,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
|
"github.com/ollama/ollama/runner/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
// input is an element of the prompt to process, either
|
// input is an element of the prompt to process, either
|
||||||
@ -498,12 +499,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||||
sequence := strings.Join(seq.pendingResponses, "")
|
sequence := strings.Join(seq.pendingResponses, "")
|
||||||
|
|
||||||
if ok, stop := findStop(sequence, seq.stop); ok {
|
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||||
|
|
||||||
var tokenTruncated bool
|
var tokenTruncated bool
|
||||||
origLen := len(seq.pendingResponses)
|
origLen := len(seq.pendingResponses)
|
||||||
seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
|
seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
|
||||||
newLen := len(seq.pendingResponses)
|
newLen := len(seq.pendingResponses)
|
||||||
|
|
||||||
// Update the cache based on the tokens that will be returned:
|
// Update the cache based on the tokens that will be returned:
|
||||||
@ -524,11 +525,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if containsStopSuffix(sequence, seq.stop) {
|
if common.ContainsStopSuffix(sequence, seq.stop) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if incompleteUnicode(sequence) {
|
if common.IncompleteUnicode(sequence) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -885,9 +886,6 @@ func (s *Server) loadModel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Execute(args []string) error {
|
func Execute(args []string) error {
|
||||||
if args[0] == "runner" {
|
|
||||||
args = args[1:]
|
|
||||||
}
|
|
||||||
fs := flag.NewFlagSet("runner", flag.ExitOnError)
|
fs := flag.NewFlagSet("runner", flag.ExitOnError)
|
||||||
mpath := fs.String("model", "", "Path to model binary file")
|
mpath := fs.String("model", "", "Path to model binary file")
|
||||||
ppath := fs.String("mmproj", "", "Path to projector binary file")
|
ppath := fs.String("mmproj", "", "Path to projector binary file")
|
24
runner/runner.go
Normal file
@ -0,0 +1,24 @@
package runner

import (
	"github.com/ollama/ollama/runner/newrunner"
	"github.com/ollama/ollama/runner/oldrunner"
)

func Execute(args []string) error {
	if args[0] == "runner" {
		args = args[1:]
	}

	var newRunner bool
	if args[0] == "--new-runner" {
		args = args[1:]
		newRunner = true
	}

	if newRunner {
		return newrunner.Execute(args)
	} else {
		return oldrunner.Execute(args)
	}
}

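As wired up here, invoking the runner subcommand behaves as before by default, while passing --new-runner as the first remaining argument (for example: ollama runner --new-runner --model /path/to/model.gguf) strips the flag and hands the rest of the arguments to newrunner.Execute instead of oldrunner.Execute.
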
@ -10,6 +10,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/model/mllama"
|
"github.com/ollama/ollama/model/mllama"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
@ -92,6 +93,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
var imgData llm.ImageData
|
var imgData llm.ImageData
|
||||||
|
|
||||||
if isMllama {
|
if isMllama {
|
||||||
|
if envconfig.NewRunners() {
|
||||||
|
imgData = llm.ImageData{
|
||||||
|
ID: len(images),
|
||||||
|
Data: i,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
data, opts, err := mllama.Preprocess(bytes.NewReader(i))
|
data, opts, err := mllama.Preprocess(bytes.NewReader(i))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
@ -113,6 +120,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
Data: buf.Bytes(),
|
Data: buf.Bytes(),
|
||||||
AspectRatioID: ar,
|
AspectRatioID: ar,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
imgPrompt = "<|image|>"
|
imgPrompt = "<|image|>"
|
||||||
} else {
|
} else {
|
||||||
imgData = llm.ImageData{
|
imgData = llm.ImageData{
|
||||||
|