From caddb1e4cfee951f89bca301ad89bc6ae59e7ce7 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sat, 22 Mar 2025 10:15:52 -0700 Subject: [PATCH] rebased --- model/models/gemma3/process_image.go | 2 +- model/models/mistral3/model.go | 30 ++++------------------------ model/models/mistral3/model_text.go | 2 +- 3 files changed, 6 insertions(+), 28 deletions(-) diff --git a/model/models/gemma3/process_image.go b/model/models/gemma3/process_image.go index 1dc7259f9..fe8269a3b 100644 --- a/model/models/gemma3/process_image.go +++ b/model/models/gemma3/process_image.go @@ -51,7 +51,7 @@ func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 { func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) { outputSize := image.Point{p.imageSize, p.imageSize} newImage := imageproc.Composite(img) - newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBicubic) + newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear) data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD) return data, nil diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index c5e484e66..80e8f381e 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -46,13 +46,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return nil, model.ErrNoVisionModel } - // Decode image image, _, err := image.Decode(bytes.NewReader(multimodalData)) if err != nil { return nil, err } - // Process image f32s, err := m.ImageProcessor.ProcessImage(image) if err != nil { return nil, err @@ -100,38 +98,18 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { return result, nil } -func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { - inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { return nil, err } - positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) + outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) if err != nil { return nil, err } - outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) - if err != nil { - return nil, err - } - - // Handle multimodal inputs - // var except []int - // hiddenState := m.TextModel.TokenEmbedding.Forward(ctx, inputs) - - // for _, image := range opts.Multimodal { - // visionOutputs := image.Multimodal.(ml.Tensor) - - // // Copy vision outputs into the hidden state - // ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) - - // for i := range visionOutputs.Dim(1) { - // except = append(except, image.Index+i) - // } - // } - - return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index f4382b4c0..52cd50b86 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -116,7 +116,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten return hiddenState.Add(ctx, residual) } -func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor { +func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor { // Process text inputs hiddenState := m.TokenEmbedding.Forward(ctx, inputs)