diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go
index c7b56890e..fbbd9d5ce 100644
--- a/integration/llm_image_test.go
+++ b/integration/llm_image_test.go
@@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) {
 	DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
 }
 
+func TestIntegrationSplitBatch(t *testing.T) {
+	image, err := base64.StdEncoding.DecodeString(imageEncoding)
+	require.NoError(t, err)
+	req := api.GenerateRequest{
+		Model: "gemma3:4b",
+		// Fill up a chunk of the batch so the image will partially spill over into the next one
+		System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
+		Prompt: "what does the text in this image say?",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"seed":        42,
+			"temperature": 0.0,
+		},
+		Images: []api.ImageData{
+			image,
+		},
+	}
+
+	// Note: sometimes it returns "the ollamas", sometimes "the ollams"
+	resp := "the ollam"
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+	// vision models on CPU can be quite slow to start
+	DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
+}
+
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
 AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
 AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6
diff --git a/model/input/input.go b/model/input/input.go
index 0cb3f3f41..30bdcf065 100644
--- a/model/input/input.go
+++ b/model/input/input.go
@@ -15,6 +15,12 @@ type Input struct {
 	// stored in Multimodal, used for caching and comparing
 	// equality.
 	MultimodalHash uint64
+
+	// SameBatch forces the following number of tokens to be processed
+	// in a single batch, breaking and extending batches as needed.
+	// Useful for things like images that must be processed in one
+	// shot.
+	SameBatch int
 }
 
 // MultimodalIndex is a multimodal element (such as an image)
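Aside: to make the SameBatch contract concrete before the model-specific changes, here is a standalone sketch. It is illustrative only — the Input struct is cut down, and splitBatches with its flush rule is an assumption for exposition, not the runner's actual code (that change is made incrementally in runner/ollamarunner/runner.go below):

package main

import "fmt"

// Input mirrors only the fields of model/input.Input relevant to batching.
type Input struct {
	Token     int32
	SameBatch int // following inputs that must share this input's batch
}

// splitBatches groups inputs into batches of at most batchSize entries,
// extending a batch only as far as a SameBatch group requires, so a group
// that would straddle a boundary is pushed whole into the next batch.
func splitBatches(inputs []Input, batchSize int) [][]Input {
	var batches [][]Input
	var pending []Input
	limit := batchSize

	for _, inp := range inputs {
		minBatch := 1 + inp.SameBatch
		if minBatch > limit {
			limit = minBatch // extend by the minimum amount possible
		}
		if len(pending)+minBatch > limit {
			// The group no longer fits: flush pending inputs and start over.
			batches = append(batches, pending)
			pending = nil
			limit = max(batchSize, minBatch)
		}
		pending = append(pending, inp)
	}
	if len(pending) > 0 {
		batches = append(batches, pending)
	}
	return batches
}

func main() {
	// Three plain tokens, then one that pins the following four to its batch.
	inputs := []Input{
		{Token: 1}, {Token: 2}, {Token: 3},
		{Token: 4, SameBatch: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8},
	}
	for i, b := range splitBatches(inputs, 4) {
		fmt.Printf("batch %d: %v\n", i, b)
	}
	// batch 0: tokens 1-3 (the group of five would straddle the boundary)
	// batch 1: tokens 4-8 (the batch is extended from 4 to 5 to keep the group whole)
}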
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index 24193f15f..ccc7567c5 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -2,10 +2,9 @@ package gemma3
 
 import (
 	"bytes"
-	"encoding/binary"
-	"hash/fnv"
 	"image"
 	"math"
+	"slices"
 
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
@@ -112,36 +111,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return visionOutputs, nil
 }
 
-type imageToken struct {
-	embedding ml.Tensor
-	index     int
-}
-
 func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
-	fnvHash := fnv.New64a()
 
 	for _, inp := range inputs {
 		if inp.Multimodal == nil {
 			result = append(result, inp)
 		} else {
-			imageInputs := []input.Input{
-				{Token: 108},    // "\n\n"
-				{Token: 255999}, // "<start_of_image>"
-			}
-			result = append(result, imageInputs...)
-
-			// add image embeddings
 			inputMultimodal := inp.Multimodal.(ml.Tensor)
 
-			for i := range inputMultimodal.Dim(1) {
-				fnvHash.Reset()
-				binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
-				fnvHash.Write([]byte{byte(i)})
+			result = append(result,
+				input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3},               // "\n\n"
+				input.Input{Token: 255999},                                                   // "<start_of_image>"
+				input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
+			)
 
-				imageToken := imageToken{embedding: inputMultimodal, index: i}
-				result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
-			}
+			// add image token placeholders
+			result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
 
 			result = append(result,
 				input.Input{Token: 256000}, // <end_of_image>
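For a concrete picture of the sequence the new PostTokenize builds, the sketch below reconstructs the layout for an image whose embedding tensor has n columns. It is a simplified model: Input is a cut-down stand-in and layoutForImage is a hypothetical helper; only the token IDs and counts come from the patch:

package main

import (
	"fmt"
	"slices"
)

// Simplified stand-in for model/input.Input; the real struct also carries
// the multimodal tensor and its hash.
type Input struct {
	Token      int32
	SameBatch  int
	Multimodal bool
}

// layoutForImage mirrors the sequence PostTokenize emits for an image whose
// embedding tensor has n columns.
func layoutForImage(n int) []Input {
	result := []Input{
		// SameBatch covers the n+3 inputs that follow: the start tag, the n
		// image positions, the end tag, and the trailing newline token.
		{Token: 108, SameBatch: n + 3}, // "\n\n"
		{Token: 255999},                // "<start_of_image>"
		{Multimodal: true},             // all n embedding columns ride on this first placeholder
	}
	// n-1 zero-token placeholders reserve space for the remaining columns
	result = append(result, slices.Repeat([]Input{{Token: 0}}, n-1)...)
	result = append(result,
		Input{Token: 256000}, // <end_of_image>
		Input{Token: 108},    // "\n\n"
	)
	return result
}

func main() {
	// n image positions plus 4 framing tokens: 256 columns -> 260 inputs.
	fmt.Println(len(layoutForImage(256)))
}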
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 7a88c0921..567f65a5e 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -171,53 +171,20 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
-	var embedding ml.Tensor
-	var src, dst, length int
-	var except []int
-
-	for _, image := range multimodal {
-		imageToken := image.Multimodal.(imageToken)
-		imageSrc := imageToken.index
-		imageDst := image.Index
-
-		if embedding == nil {
-			embedding = imageToken.embedding
-			src = imageSrc
-			dst = imageDst
-			length = 1
-		} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
-			src = imageSrc
-			dst = imageDst
-			length++
-		} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
-			length++
-		} else {
-			visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
-			ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
-
-			embedding = imageToken.embedding
-			src = imageSrc
-			dst = imageDst
-			length = 1
-		}
-
-		except = append(except, imageDst)
-	}
-
-	if embedding != nil {
-		visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
-		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
-	}
-
-	return except
-}
-
 func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
-	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
+	// set image embeddings
+	var except []int
+	for _, image := range opts.Multimodal {
+		visionOutputs := image.Multimodal.(ml.Tensor)
+		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
+
+		for i := range visionOutputs.Dim(1) {
+			except = append(except, image.Index+i)
+		}
+	}
 
 	for i, layer := range m.Layers {
 		// gemma alternates between the sliding window (local) and causal (global)
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index d6339a615..916ad45da 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -352,6 +352,8 @@ func (s *Server) processBatch() error {
 			seq.cache.Inputs = []input.Input{}
 		}
 
+		batchSize := s.batchSize
+
 		for j, inp := range seq.inputs {
 			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
 				if len(seq.pendingInputs) == 0 {
@@ -364,7 +366,15 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			if j >= s.batchSize {
+			// If the following inputs are required to be in a single batch, extend the
+			// batch size. Since we only extend it by the minimum amount needed, this
+			// will cause a break if there are already pending inputs.
+			minBatch := 1 + inp.SameBatch
+			if minBatch > batchSize {
+				batchSize = minBatch
+			}
+
+			if len(seq.pendingInputs)+minBatch > batchSize {
 				break
 			}
 
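The batch-extension arithmetic is easy to check by hand. In this sketch, decide is a hypothetical extraction of the new check (in the real code the pending count and batch size come from the runner's state); a Gemma 3 image typically encodes to 256 embedding columns, so its leading token carries SameBatch = 259:

package main

import "fmt"

// decide mirrors the check added to processBatch: extend the batch only as
// far as the SameBatch group needs, and break if pending inputs would make
// the group straddle the batch boundary.
func decide(pending, batchSize, sameBatch int) (newBatchSize int, breakBatch bool) {
	minBatch := 1 + sameBatch
	if minBatch > batchSize {
		batchSize = minBatch
	}
	return batchSize, pending+minBatch > batchSize
}

func main() {
	// 500 tokens already pending, image group needs 260 slots, batch size 512:
	// 500+260 > 512, so the runner breaks and the image starts the next batch.
	fmt.Println(decide(500, 512, 259)) // 512 true

	// With an empty batch the same group fits: 0+260 <= 512, no break.
	fmt.Println(decide(0, 512, 259)) // 512 false

	// A group larger than the configured batch size extends it just enough.
	fmt.Println(decide(0, 512, 600)) // 601 false
}

This also shows why the break in the test above is deterministic: the long System prompt fills most of the default batch, so the image's SameBatch group cannot fit alongside the pending tokens and is carried whole into the following batch.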