diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index 07bc788b4..af01bb6f1 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -424,6 +424,17 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 	return out, nil
 }
 
+func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
+	s := make([]float32, 0, int((stop-start)/step))
+	for i := start; i < stop; i += step {
+		s = append(s, i)
+	}
+
+	out, _ := c.FromFloatSlice(s, len(s))
+	out.(*testTensor).dtype = dtype
+	return out
+}
+
 func (c *testContext) Input() ml.Context    { return c }
 func (c *testContext) Layer(int) ml.Context { return c }
diff --git a/ml/backend.go b/ml/backend.go
index b2a83cfd5..70c2fd8e2 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -95,6 +95,9 @@ type Context interface {
 	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
 	FromIntSlice(s []int32, shape ...int) (Tensor, error)
 
+	// Arange creates a 1D tensor with values in the half-open interval [start, stop), spaced by step.
+	Arange(start, stop, step float32, dtype DType) Tensor
+
 	Forward(...Tensor) Context
 	Compute(...Tensor)
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 94fc87a3d..c486b7477 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -696,6 +696,32 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 	return t, nil
 }
 
+func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
+	switch dtype {
+	case ml.DTypeF32:
+		// ggml_arange creates a float32 tensor
+		return &Tensor{
+			b: c.b,
+			t: C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)),
+		}
+	case ml.DTypeI32:
+		// ggml_cast does not support float32 to int32 conversion, so materialize the values host-side
+		arange := make([]int32, 0, int((stop-start)/step))
+		for i := start; i < stop; i += step {
+			arange = append(arange, int32(i))
+		}
+
+		t, err := c.Input().FromIntSlice(arange, len(arange))
+		if err != nil {
+			panic(err)
+		}
+
+		return t
+	default:
+		panic("unsupported dtype for arange")
+	}
+}
+
 func (c *Context) Close() {
 	if c != nil {
 		for _, b := range *c.allocatedBuffers {
diff --git a/model/models/gemma3/model_vision.go b/model/models/gemma3/model_vision.go
index 636a363df..8b1a8eb00 100644
--- a/model/models/gemma3/model_vision.go
+++ b/model/models/gemma3/model_vision.go
@@ -92,16 +92,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
 	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 
-	positions := make([]int32, numPatches)
-	for i := range positions {
-		positions[i] = int32(i)
-	}
-
-	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
-	if err != nil {
-		panic(err)
-	}
-
+	positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeI32)
 	hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
 
 	for _, layer := range m.Layers {
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index e53eb184c..a0fc6b693 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -93,16 +93,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 		return nil, err
 	}
 
-	positions := make([]int32, 1601)
-	for i := range positions {
-		positions[i] = int32(i)
-	}
-
-	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
-	if err != nil {
-		return nil, err
-	}
-
+	positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
 	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
 	return m.Projector.Forward(ctx, crossAttentionStates), nil
 }
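
Note: every implementation in this patch relies on the same half-open interval convention. A minimal standalone sketch of those semantics follows; arangeF32 is an illustrative name, not part of the patch, and only the Go standard library is assumed:

package main

import "fmt"

// arangeF32 mirrors the loop used by the testContext stub and the
// DTypeI32 fallback above: values begin at start and advance by step
// while strictly below stop, so stop itself is never included.
func arangeF32(start, stop, step float32) []float32 {
	s := make([]float32, 0, int((stop-start)/step))
	for i := start; i < stop; i += step {
		s = append(s, i)
	}
	return s
}

func main() {
	fmt.Println(arangeF32(0, 5, 1))   // [0 1 2 3 4]
	fmt.Println(arangeF32(0, 2, 0.5)) // [0 0.5 1 1.5]
}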
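
On the design of the DTypeI32 branch: ggml_arange only emits float32 tensors and, per the comment in the patch, ggml_cast cannot convert float32 to int32, so the integer variant builds the slice host-side and uploads it through FromIntSlice. This lets gemma3 and mllama drop their hand-rolled position-ID loops (both previously filled a 0..n-1 slice and called FromIntSlice themselves) at the cost of a small host-to-device copy.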