diff --git a/model/input/input.go b/model/input/input.go index 0cb3f3f41..a1247bca7 100644 --- a/model/input/input.go +++ b/model/input/input.go @@ -15,6 +15,12 @@ type Input struct { // stored in Multimodal, used for caching and comparing // equality. MultimodalHash uint64 + + // BatchBreak forces a new batch to be started with this + // input. For example, this can be used to align images + // with batches. Note that batches may be divided in additional + // locations as well. + BatchBreak bool } // MultimodalIndex is a multimodal element (such as an image) diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 7418bb12f..2fe04348d 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -112,8 +112,8 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu result = append(result, inp) } else { imageInputs := []input.Input{ - {Token: 108}, // "\n\n" - {Token: 255999}, // """ + {Token: 108}, // "\n\n" + {Token: 255999, BatchBreak: true}, // """ } result = append(result, imageInputs...) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index c1475cbb2..9b997bd37 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -363,7 +363,7 @@ func (s *Server) processBatch() error { } } - if j >= s.batchSize { + if j >= s.batchSize || (inp.BatchBreak && len(seq.pendingInputs) != 0) { break } diff --git a/server/prompt.go b/server/prompt.go index d053f2a8d..5b5b958f1 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -26,7 +26,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. var system []api.Message isMllama := checkMllamaModelFamily(m) - isGemma3 := checkGemma3ModelFamily(m) var imageNumTokens int // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent @@ -41,7 +40,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. n := len(msgs) - 1 // in reverse, find all messages that fit into context window for i := n; i >= 0; i-- { - if (isMllama || isGemma3) && len(msgs[i].Images) > 1 { + if isMllama && len(msgs[i].Images) > 1 { return "", nil, errTooManyImages } @@ -158,12 +157,3 @@ func checkMllamaModelFamily(m *Model) bool { } return false } - -func checkGemma3ModelFamily(m *Model) bool { - for _, arch := range m.Config.ModelFamilies { - if arch == "gemma3" { - return true - } - } - return false -}