# Compare commits

`main...brucemacd/` (4 commits)

- 7fa9694359
- 96510b9353
- 9f8c89354b
- 8815a8ee25
## .gitignore (vendored, +3)

```diff
@@ -14,3 +14,6 @@ test_data
 __debug_bin*
 llama/build
 llama/vendor
+model/testdata/models/*
+!model/testdata/models/*.md
+!model/testdata/models/*.json
```
## Test tensor stub (`testTensor`)

```diff
@@ -430,7 +430,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
 	panic("not implemented")
 }
 
-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
 	panic("not implemented")
 }
 
```
## `ml` package: `RopeType` and `RopeConfig`

```diff
@@ -43,6 +43,42 @@ func NewBackend(f *os.File) (Backend, error) {
 	return nil, fmt.Errorf("unsupported backend")
 }
 
+// RopeType specifies the type of RoPE (Rotary Position Embedding) to use, these types are implemented in the backend
+type RopeType int
+
+const (
+	RopeTypeStandard RopeType = iota
+	_ // not yet used
+	RopeTypeNeoX
+)
+
+// RopeConfig contains all configuration for the RoPE (Rotary Position Embedding) operation
+type RopeConfig struct {
+	// PositionIDs contains the position indices for each token in the sequence
+	// These indices are used to calculate the rotary embeddings
+	PositionIDs Tensor
+
+	// RopeFactors is an optional tensor containing pre-computed rotation factors
+	RopeFactors Tensor
+
+	// RopeDim specifies the dimension size for the rotary embeddings
+	RopeDim uint32
+
+	// RopeType indicates which RoPE variant to use (e.g. normal or neox)
+	RopeType RopeType
+
+	// OrigCtxLen stores the original context length the model was trained with
+	OrigCtxLen int
+
+	// RopeBase is the base value used in the frequency calculation
+	RopeBase float32
+
+	// RopeScale is a scaling factor applied to position indices
+	RopeScale float32
+
+	// YaRN parameters can be added here if they need to be configurable
+}
+
 type Context interface {
 	Zeros(dtype DType, shape ...int) Tensor
 	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
@@ -75,7 +111,7 @@ type Tensor interface {
 	Scale(ctx Context, s float64) Tensor
 
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
+	RoPE(ctx Context, rc RopeConfig) Tensor
 
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
```
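The struct replaces six positional arguments at every `RoPE` call site. A minimal sketch of the new call pattern (the numeric values are illustrative placeholders, not from this diff; model code reads them from GGUF metadata as the hunks below show):

```go
// Sketch only: build one RopeConfig and reuse it for both projections.
func applyRope(ctx ml.Context, q, k, positions ml.Tensor) (ml.Tensor, ml.Tensor) {
	rc := ml.RopeConfig{
		PositionIDs: positions,       // one position index per token
		RopeFactors: nil,             // optional; the ggml backend substitutes an empty tensor
		RopeDim:     64,              // placeholder; from GGUF "rope.dimension_count"
		RopeType:    ml.RopeTypeNeoX, // or ml.RopeTypeStandard for llama-style models
		OrigCtxLen:  32768,           // placeholder; from GGUF "context_length"
		RopeBase:    1000000,         // placeholder; from GGUF "rope.freq_base"
		RopeScale:   1,               // placeholder; from GGUF "rope.freq_scale"
	}
	return q.RoPE(ctx, rc), k.RoPE(ctx, rc)
}
```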
## ggml backend: `Tensor.RoPE`

```diff
@@ -579,13 +579,9 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }
 
-const (
-	ropeTypeNorm C.int = iota
-)
-
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
-	if ropeFactors == nil {
-		ropeFactors = &Tensor{}
+func (t *Tensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
+	if rc.RopeFactors == nil {
+		rc.RopeFactors = &Tensor{}
 	}
 
 	dequant := t.t
@@ -595,12 +591,15 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
 
 	return &Tensor{
 		t: C.ggml_rope_ext(
-			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-			C.int(ropeDim),
-			131072,       // YaRN n_ctx_train
-			ropeTypeNorm, // ROPE_TYPE_NORM
-			C.float(ropeBase),
-			C.float(ropeScale),
+			ctx.(*Context).ctx,
+			dequant,
+			rc.PositionIDs.(*Tensor).t,
+			rc.RopeFactors.(*Tensor).t,
+			C.int(rc.RopeDim),
+			C.int(rc.RopeType),
+			C.int(rc.OrigCtxLen),
+			C.float(rc.RopeBase),
+			C.float(rc.RopeScale),
 			0., // YaRN ext_factor
 			1., // YaRN attn_factor
 			32., // YaRN beta_fast
```
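The hard-coded `131072` training context and the local `ropeTypeNorm` constant are gone: rope type, original context length, base, and scale now all flow in through `RopeConfig`, while the YaRN `ext_factor`, `attn_factor`, and `beta_fast` arguments stay fixed, matching the struct's trailing comment. A hedged reading of why the `RopeType` const block skips a value (my inference, not stated in the diff): it keeps the Go enum numerically aligned with ggml's rope mode constants, so `C.int(rc.RopeType)` can be passed to `ggml_rope_ext` without translation:

```go
// Inferred alignment with ggml's rope modes (assumption, not from the diff):
const (
	RopeTypeStandard RopeType = iota // 0 - ggml's "normal" rope mode
	_                                // 1 - skipped; reserved / unused here
	RopeTypeNeoX                     // 2 - matches GGML_ROPE_TYPE_NEOX
)
```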
## model/model_external_test.go (new file, +138)

```go
// Package model_test provides external tests for the model package.
// This test file specifically tests the forward pass functionality on models.
// It is in a separate package (model_test) to avoid import cycles while still
// being able to test the public API of the model package.
package model_test

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/sample"

	_ "github.com/ollama/ollama/model/models"
)

type modelTest struct {
	Prompt            string   `json:"prompt"`
	OutputContainsOne []string `json:"output_contains_one"`
}

func TestForwardSimple(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Read all JSON files from testdata/models
	files, err := os.ReadDir("testdata/models")
	if err != nil {
		t.Fatal(err)
	}

	for _, file := range files {
		if !strings.HasSuffix(file.Name(), ".json") {
			continue
		}

		jsonPath := filepath.Join("testdata/models", file.Name())
		ggufPath := filepath.Join("testdata/models", strings.TrimSuffix(file.Name(), ".json")+".gguf")

		// Skip if no corresponding .gguf file exists
		if _, err := os.Stat(ggufPath); err != nil {
			t.Logf("skipping %s: no corresponding GGUF file found", file.Name())
			continue
		}

		data, err := os.ReadFile(jsonPath)
		if err != nil {
			t.Fatal(err)
		}

		var test modelTest
		if err := json.Unmarshal(data, &test); err != nil {
			t.Fatal(err)
		}

		t.Run(strings.TrimSuffix(file.Name(), ".json"), func(t *testing.T) {
			m, err := model.New(ggufPath)
			if err != nil {
				t.Fatal(err)
			}

			m.Config().Cache.Init(m.Backend(), ml.DTypeF32, 2048)

			inputs, err := m.(model.TextProcessor).Encode(test.Prompt)
			if err != nil {
				t.Fatal(err)
			}

			var result []string
			for len(result) < 100 { // Limit to 100 tokens max
				options := model.Options{
					Inputs:    inputs,
					Positions: make([]int32, len(inputs)),
					Sequences: make([]int, len(inputs)),
					Outputs:   []int32{int32(len(inputs) - 1)},
				}
				for i := range options.Positions {
					options.Positions[i] = int32(i)
					options.Sequences[i] = 0
				}

				ctx := m.Backend().NewContext()

				modelOutput, err := model.Forward(ctx, m, options)
				if err != nil {
					ctx.Close()
					t.Fatal(fmt.Errorf("forward pass failed: %v", err))
				}

				f32s := modelOutput.Floats()
				logits := make([]float64, len(f32s))
				for i, f32 := range f32s {
					logits[i] = float64(f32)
				}

				token, err := sample.Sample(logits, sample.Greedy())
				if err != nil {
					ctx.Close()
					t.Fatal(fmt.Errorf("sampling failed: %v", err))
				}

				ctx.Close()

				// Greedy sampling: take the token with the highest logit
				nextToken := int32(token[0])
				if m.(model.TextProcessor).Is(nextToken, model.SpecialEOS) {
					break
				}

				piece, err := m.(model.TextProcessor).Decode([]int32{nextToken})
				if err != nil {
					t.Fatal(err)
				}

				result = append(result, piece)
				output := strings.Join(result, "")

				for _, expectedOutput := range test.OutputContainsOne {
					if strings.Contains(output, expectedOutput) {
						t.Logf("Test passed with output: %q (matched expected: %q)", output, expectedOutput)
						return
					}
				}

				// Maintain full context by appending new token
				inputs = append(inputs, nextToken)
			}

			t.Fatalf("Expected output containing one of %q but got: %q", test.OutputContainsOne, strings.Join(result, ""))
		})
	}
}
```
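A note on the harness design: each `.json` spec under `testdata/models` is paired with a same-named `.gguf` weights file and skipped when the weights are absent, so the suite is a no-op on a fresh checkout. Each decode step also rebuilds `model.Options` over the full, growing `inputs` slice with a fresh backend context, which is quadratic in output length but keeps the loop simple. A sketch of how a new case would be laid out (hypothetical file names, echoing the README below):

```go
// Hypothetical case layout; only the .json side is committed:
//
//	model/testdata/models/llama3.json  // {"prompt": "...", "output_contains_one": ["..."]}
//	model/testdata/models/llama3.gguf  // weights, git-ignored by the new .gitignore rules
//
// With weights in place, run e.g.:
//
//	go test -run TestForwardSimple ./model/
```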
## llama model

```diff
@@ -10,10 +10,10 @@ import (
 )
 
 type Options struct {
 	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
-	hiddenSize, numHeads, numKVHeads int
+	ctxLen, hiddenSize, numHeads, numKVHeads int
 	eps, ropeBase, ropeScale float32
 	ropeDim uint32
 }
 
 type Model struct {
@@ -46,6 +46,7 @@ func New(c ml.Config) (model.Model, error) {
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
+			ctxLen:     int(c.Uint("context_length")),
 			ropeBase:   c.Float("rope.freq_base"),
 			ropeScale:  c.Float("rope.freq_scale", 1),
 			ropeDim:    c.Uint("rope.dimension_count"),
@@ -67,14 +68,23 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	rc := ml.RopeConfig{
+		PositionIDs: positionIDs,
+		RopeFactors: opts.RopeFactors,
+		RopeDim:     opts.ropeDim,
+		RopeType:    ml.RopeTypeStandard,
+		OrigCtxLen:  opts.ctxLen,
+		RopeBase:    opts.ropeBase,
+		RopeScale:   opts.ropeScale,
+	}
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, rc)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, rc)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -99,7 +109,18 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(
+		ctx,
+		ml.RopeConfig{
+			PositionIDs: shift,
+			RopeFactors: m.Options.RopeFactors,
+			RopeDim:     m.Options.ropeDim,
+			RopeType:    ml.RopeTypeStandard,
+			OrigCtxLen:  m.Options.ctxLen,
+			RopeBase:    m.Options.ropeBase,
+			RopeScale:   m.Options.ropeScale,
+		},
+	), nil
 }
 
 type MLP struct {
```
## mllama text model

```diff
@@ -19,14 +19,23 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	rc := ml.RopeConfig{
+		PositionIDs: positions,
+		RopeFactors: opts.RopeFactors,
+		RopeDim:     opts.ropeDim,
+		RopeType:    ml.RopeTypeStandard,
+		OrigCtxLen:  opts.ctxLen,
+		RopeBase:    opts.ropeBase,
+		RopeScale:   opts.ropeScale,
+	}
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, rc)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, rc)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -52,7 +61,18 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
-	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(
+		ctx,
+		ml.RopeConfig{
+			PositionIDs: shift,
+			RopeFactors: m.RopeFactors,
+			RopeDim:     m.ropeDim,
+			RopeType:    ml.RopeTypeStandard,
+			OrigCtxLen:  m.ctxLen,
+			RopeBase:    m.ropeBase,
+			RopeScale:   m.ropeScale,
+		},
+	), nil
 }
 
 type TextMLP struct {
@@ -189,9 +209,9 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
 type TextModelOptions struct {
 	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
 
-	hiddenSize, numHeads, numKVHeads int
+	ctxLen, hiddenSize, numHeads, numKVHeads int
 	eps, ropeBase, ropeScale float32
 	ropeDim uint32
 
 	crossAttentionLayers []uint32
 }
```
## Model registry (`package models`)

```diff
@@ -3,4 +3,5 @@ package models
 import (
 	_ "github.com/ollama/ollama/model/models/llama"
 	_ "github.com/ollama/ollama/model/models/mllama"
+	_ "github.com/ollama/ollama/model/models/qwen2"
 )
```
## model/models/qwen2/model.go (new file, +222)

```go
package qwen2

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
)

type Options struct {
	RopeFactors    ml.Tensor `gguf:"rope_freqs.weight"`
	contextLength  int
	hiddenSize     int
	numAttnHeads   int
	numKVHeads     int
	modelEpsilon   float32
	ropeBaseFreq   float32
	ropeFreqScale  float32
	ropeDimensions uint32
}

type Model struct {
	model.Base
	model.BytePairEncoding

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*Options
}

func New(c ml.Config) (model.Model, error) {
	m := &Model{
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
			},
		),
		Layers: make([]Layer, c.Uint("block_count")),
		Options: &Options{
			hiddenSize:     int(c.Uint("embedding_length")),
			numAttnHeads:   int(c.Uint("attention.head_count")),
			numKVHeads:     int(c.Uint("attention.head_count_kv")),
			modelEpsilon:   c.Float("attention.layer_norm_rms_epsilon"),
			contextLength:  int(c.Uint("context_length")),
			ropeBaseFreq:   c.Float("rope.freq_base"),
			ropeFreqScale:  c.Float("rope.freq_scale", 1),
			ropeDimensions: c.Uint("rope.dimension_count", 64),
		},
	}

	m.Cache = kvcache.NewCausalCache(m.Shift)

	return m, nil
}

// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return key.RoPE(
		ctx,
		ml.RopeConfig{
			PositionIDs: shift,
			RopeFactors: m.Options.RopeFactors,
			RopeDim:     m.Options.ropeDimensions,
			RopeType:    ml.RopeTypeNeoX,
			OrigCtxLen:  m.Options.contextLength,
			RopeBase:    m.Options.ropeBaseFreq,
			RopeScale:   m.Options.ropeFreqScale,
		},
	), nil
}

// SelfAttention implements the multi-head self-attention mechanism
// with separate projections for query, key, value and output transformations
type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, inputPositions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	// Initialize dimensions and configuration
	batchSize := hiddenState.Dim(1)
	headDimension := opts.hiddenSize / opts.numAttnHeads
	ropeConfig := ml.RopeConfig{
		PositionIDs: inputPositions,
		RopeFactors: nil,
		RopeDim:     opts.ropeDimensions,
		RopeType:    ml.RopeTypeNeoX,
		OrigCtxLen:  opts.contextLength,
		RopeBase:    opts.ropeBaseFreq,
		RopeScale:   opts.ropeFreqScale,
	}

	// Project and reshape query states with rotary embeddings
	queryStates := sa.Query.Forward(ctx, hiddenState)
	queryStates = queryStates.Reshape(ctx, headDimension, opts.numAttnHeads, batchSize)
	queryStates = queryStates.RoPE(ctx, ropeConfig)

	// Project and reshape key states with rotary embeddings
	keyStates := sa.Key.Forward(ctx, hiddenState)
	keyStates = keyStates.Reshape(ctx, headDimension, opts.numKVHeads, batchSize)
	keyStates = keyStates.RoPE(ctx, ropeConfig)

	// Project and reshape value states
	valueStates := sa.Value.Forward(ctx, hiddenState)
	valueStates = valueStates.Reshape(ctx, headDimension, opts.numKVHeads, batchSize)

	// Update and retrieve from KV cache
	cache.Put(ctx, keyStates, valueStates)
	keyStates, valueStates, attentionMask := cache.Get(ctx)

	// Prepare tensors for attention computation
	queryStates = queryStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	keyStates = keyStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	valueStates = valueStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	// Apply scaling and attention mask to scores
	attentionScores := keyStates.MulmatFullPrec(ctx, queryStates)
	attentionScores = attentionScores.Scale(ctx, 1.0/math.Sqrt(float64(headDimension)))
	attentionScores = attentionScores.Add(ctx, attentionMask)
	// Compute scaled dot-product attention
	attentionProbs := attentionScores.Softmax(ctx)

	// Apply attention weights and reshape
	weightedStates := valueStates.Mulmat(ctx, attentionProbs)
	weightedStates = weightedStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	weightedStates = weightedStates.Reshape(ctx, opts.hiddenSize, batchSize)

	// Project to output dimension
	return sa.Output.Forward(ctx, weightedStates)
}

// MLP implements the feed-forward network component with SwiGLU activation
type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	// Apply SwiGLU activation gating
	gateActivation := mlp.Gate.Forward(ctx, hiddenState).SILU(ctx)
	upProjection := mlp.Up.Forward(ctx, hiddenState)
	intermediateStates := gateActivation.Mul(ctx, upProjection)

	// Project back to hidden dimension
	return mlp.Down.Forward(ctx, intermediateStates)
}

// Layer represents a single transformer layer combining self-attention and feed-forward components
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	// Self-attention branch with residual connection
	residual := hiddenState

	normalizedAttention := l.AttentionNorm.Forward(ctx, hiddenState, opts.modelEpsilon)
	attentionOutput := l.SelfAttention.Forward(ctx, normalizedAttention, positionIDs, cache, opts)
	hiddenState = attentionOutput.Add(ctx, residual)

	// Feed-forward branch with residual connection
	residual = hiddenState
	normalizedMLP := l.MLPNorm.Forward(ctx, hiddenState, opts.modelEpsilon)
	mlpOutput := l.MLP.Forward(ctx, normalizedMLP, opts)
	output := mlpOutput.Add(ctx, residual)

	return output
}

func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
	// Convert input tokens and positions to tensors
	inputTensor, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
	if err != nil {
		return nil, err
	}

	positionsTensor, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
	if err != nil {
		return nil, err
	}

	// Initial token embedding
	hiddenStates := m.TokenEmbedding.Forward(ctx, inputTensor)

	// Process through transformer layers
	for i, layer := range m.Layers {
		m.Cache.SetLayer(i)
		hiddenStates = layer.Forward(ctx, hiddenStates, positionsTensor, m.Cache, m.Options)
	}

	// Final layer normalization and output projection
	normalizedOutput := m.OutputNorm.Forward(ctx, hiddenStates, m.modelEpsilon)
	logits := m.Output.Forward(ctx, normalizedOutput)

	// Extract requested output token positions
	outputsTensor, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
	if err != nil {
		return nil, err
	}

	return logits.Rows(ctx, outputsTensor), nil
}

func init() {
	model.Register("qwen2", New)
}
```
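The attention block here is grouped-query attention: `numAttnHeads` query heads attend over the smaller set of `numKVHeads` key/value heads, so the KV cache only holds the reduced projections. A shape sketch under assumed Qwen2.5-7B-style metadata (illustrative values only; the real ones come from the GGUF header):

```go
// Assumed metadata for illustration (Qwen2.5-7B-style); not from this diff.
const (
	hiddenSize   = 3584
	numAttnHeads = 28
	numKVHeads   = 4
)

const headDimension = hiddenSize / numAttnHeads // 128

// Shapes after the Reshape calls above:
//
//	queryStates: [headDimension, numAttnHeads, batchSize] = [128, 28, batch]
//	keyStates:   [headDimension, numKVHeads, batchSize]   = [128,  4, batch]
//	valueStates: [headDimension, numKVHeads, batchSize]   = [128,  4, batch]
//
// Each KV head is shared by 28/4 = 7 query heads.
```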
## model/testdata/models/README.md (vendored, new file, +10)

```markdown
# Test Model Directory

This directory is used for storing model files (like `.gguf` files) that are required to run the tests in `model_external_test.go`.

## Usage

- Place any model files you need for testing in this directory
- The test file will look for model files here (e.g., `llama3.gguf`) that have a matching `.json` test spec
- Model files in this directory are git-ignored to prevent large files from being committed to the repository
- Only `.md` files (like this README) and `.json` test specs will be tracked in git
```
## model/testdata/models/qwen2_5.json (vendored, new file, +7)

```json
{
  "prompt": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n",
  "output_contains_one": [
    "Hello",
    "Hi"
  ]
}
```
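With a matching `model/testdata/models/qwen2_5.gguf` present (not committed, per the `.gitignore` rules above), `TestForwardSimple` picks up this spec, encodes the ChatML-style prompt, greedily decodes up to 100 tokens, and passes as soon as the generated text contains either `Hello` or `Hi`.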