Compare commits
3 commits: main...qwen25omni

| SHA1 |
|---|
| db1146f4e8 |
| 8f9eafda06 |
| 23267d783b |
@@ -196,6 +196,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 		conv = &phi3Model{}
 	case "Qwen2ForCausalLM":
 		conv = &qwen2Model{}
+	case "Qwen2_5_VLForConditionalGeneration":
+		conv = &qwen25VLModel{}
+	case "Qwen2_5OmniModel":
+		conv = &qwen25OmniModel{}
 	case "BertModel":
 		conv = &bertModel{}
 	case "CohereForCausalLM":
convert/convert_qwen25_omni.go — new file, 209 lines

@@ -0,0 +1,209 @@
package convert

import (
	"bytes"
	"encoding/binary"
	"io"
	"log/slog"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/x448/float16"

	"github.com/ollama/ollama/fs/ggml"
)

type qwen25OmniModel struct {
	ModelParameters
	TalkerModel struct {
		AudioEndTokenID       uint32  `json:"audio_end_token_id"`
		AudioStartTokenID     uint32  `json:"audio_start_token_id"`
		AudioTokenIndex       uint32  `json:"audio_token_index"`
		HeadDim               uint32  `json:"head_dim"`
		HiddenSize            uint32  `json:"hidden_size"`
		ImageTokenIndex       uint32  `json:"image_token_index"`
		IntermediateSize      uint32  `json:"intermediate_size"`
		MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
		MaxWindowLayers       uint32  `json:"max_window_layers"`
		NumAttentionHeads     uint32  `json:"num_attention_heads"`
		HiddenLayers          uint32  `json:"num_hidden_layers"`
		NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
		RMSNormEPS            float32 `json:"rms_norm_eps"`
		RopeTheta             float32 `json:"rope_theta"`
		VideoTokenIndex       uint32  `json:"video_token_index"`
		VisionEndTokenID      uint32  `json:"vision_end_token_id"`
		VisionStartTokenID    uint32  `json:"vision_start_token_id"`
	} `json:"talker_config"`

	ThinkerModel struct {
		TextModel struct {
			HiddenSize            uint32  `json:"hidden_size"`
			IntermediateSize      uint32  `json:"intermediate_size"`
			MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
			NumAttentionHeads     uint32  `json:"num_attention_heads"`
			HiddenLayers          uint32  `json:"num_hidden_layers"`
			RopeTheta             float32 `json:"rope_theta"`
			NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
			RMSNormEPS            float32 `json:"rms_norm_eps"`
		} `json:"text_config"`
	} `json:"thinker_config"`

	VisionModel struct {
	} `json:"vision_config"`

	Token2WavModel struct {
	} `json:"token2wav_config"`
}

var _ ModelConverter = (*qwen25OmniModel)(nil)

func (q *qwen25OmniModel) KV(t *Tokenizer) ggml.KV {
	kv := q.ModelParameters.KV(t)
	kv["general.architecture"] = "qwen25omni"
	kv["qwen25omni.block_count"] = q.ThinkerModel.TextModel.HiddenLayers
	kv["qwen25omni.context_length"] = q.ThinkerModel.TextModel.MaxPositionEmbeddings
	kv["qwen25omni.embedding_length"] = q.ThinkerModel.TextModel.HiddenSize
	kv["qwen25omni.feed_forward_length"] = q.ThinkerModel.TextModel.IntermediateSize
	kv["qwen25omni.attention.head_count"] = q.ThinkerModel.TextModel.NumAttentionHeads
	kv["qwen25omni.attention.head_count_kv"] = q.ThinkerModel.TextModel.NumKeyValueHeads
	kv["qwen25omni.rope.freq_base"] = q.ThinkerModel.TextModel.RopeTheta
	kv["qwen25omni.attention.layer_norm_rms_epsilon"] = q.ThinkerModel.TextModel.RMSNormEPS

	return kv
}

func (q *qwen25OmniModel) Tensors(ts []Tensor) []ggml.Tensor {
	var out []ggml.Tensor

	for _, t := range ts {
		if strings.HasSuffix(t.Name(), "patch_embed.proj.weight") {
			var buf bytes.Buffer
			t.WriteTo(&buf)
			newTensors := splitPatchEmbed(buf, t.Kind(), t.Shape())
			out = append(out, newTensors...)
		} else {
			out = append(out, ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    t.Shape(),
				WriterTo: t,
			})
		}
	}

	return out
}

func splitPatchEmbed(buf bytes.Buffer, kind uint32, shape []uint64) []ggml.Tensor {
	slog.Debug("patch stuff", "kind", kind, "shape", shape)

	if kind != tensorKindF16 {
		panic("tensor is of wrong type")
	}

	if len(shape) != 5 || (len(shape) == 5 && shape[2] != 2) {
		panic("wrong sized tensor")
	}

	// determine the size of the tensor based on its shape
	shapeToSize := func(s []int) int {
		r := 1
		for _, n := range s {
			r *= int(n)
		}
		return r
	}

	// tensor.WithShape() wants []int
	intShape := make([]int, len(shape))
	for i, v := range shape {
		intShape[i] = int(v)
	}

	u16s := make([]uint16, shapeToSize(intShape))
	if err := binary.Read(&buf, binary.LittleEndian, u16s); err != nil {
		panic("bad read")
	}

	f32s := make([]float32, len(u16s))
	for i := range u16s {
		f32s[i] = float16.Frombits(u16s[i]).Float32()
	}

	newTensors := []ggml.Tensor{}

	getDataFromSlice := func(f32s []float32, shape []int, s []tensor.Slice) patchEmbed {
		slog.Debug("getDataFromSlice", "num f32s", len(f32s), "shape", shape)
		n := tensor.New(tensor.WithShape(shape...), tensor.WithBacking(f32s))
		t, err := n.Slice(s...)
		if err != nil {
			panic(err)
		}

		ts, err := native.SelectF32(t.Materialize().(*tensor.Dense), 0)
		if err != nil {
			panic(err)
		}

		slog.Debug("first vals", "val 1", ts[0][0], "val 2", ts[0][1], "val 3", ts[0][2])

		f16s := make(patchEmbed, shapeToSize(shape))
		for r, row := range ts {
			for c, col := range row {
				f16s[r+c] = float16.Fromfloat32(col).Bits()
			}
		}

		return f16s
	}

	p := getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(0, 1, 1), nil, nil})
	newTensors = append(newTensors, ggml.Tensor{
		Name:     "patch_embed.proj.0.weight",
		Kind:     kind,
		Shape:    append(shape[:2], shape[3:]...),
		WriterTo: p,
	})

	p = getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(1, 2, 1), nil, nil})
	newTensors = append(newTensors, ggml.Tensor{
		Name:     "patch_embed.proj.1.weight",
		Kind:     kind,
		Shape:    append(shape[:2], shape[3:]...),
		WriterTo: p,
	})

	return newTensors
}

type patchEmbed []uint16

func (t patchEmbed) WriteTo(w io.Writer) (int64, error) {
	err := binary.Write(w, binary.LittleEndian, t)
	return 0, err
}

func (p *qwen25OmniModel) Replacements() []string {
	return []string{
		"lm_head", "output",
		"model.embed_tokens", "token_embd",
		"thinker.audio_tower.layers", "a.blk",
		"thinker.visual.blocks", "v.blk",
		"thinker.model.layers", "blk",
		"talker.model.layers", "tlk.blk",
		"token2wav.code2wav_bigvgan_model", "t2w.b",
		"token2wav.code2wav_dit_model", "t2w.d",
		"input_layernorm", "attn_norm",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.q_proj", "attn_q",
		"self_attn.o_proj", "attn_output",
		"mlp.down_proj", "ffn_down",
		"mlp.gate_proj", "ffn_gate",
		"mlp.up_proj", "ffn_up",
		"post_attention_layernorm", "ffn_norm",
		"model.norm", "output_norm",
	}
}
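Aside (not part of the diff): splitPatchEmbed above takes the 5-D Conv3D patch-embedding weight of shape [out, in, 2, kh, kw], slices it along the temporal axis of size 2, and emits two Conv2D-shaped tensors, patch_embed.proj.0.weight and patch_embed.proj.1.weight. A minimal, dependency-free sketch of the same index arithmetic, using a hypothetical helper name:

// splitTemporal illustrates the split performed by splitPatchEmbed: it copies a
// row-major [o, i, 2, kh, kw] buffer into two [o, i, kh, kw] buffers by selecting
// temporal index 0 and temporal index 1. It assumes len(src) == o*i*2*kh*kw.
func splitTemporal(src []uint16, o, i, kh, kw int) (t0, t1 []uint16) {
	plane := kh * kw
	t0 = make([]uint16, o*i*plane)
	t1 = make([]uint16, o*i*plane)
	for oi := 0; oi < o*i; oi++ {
		// for each (out, in) pair, the two temporal planes are contiguous in src
		copy(t0[oi*plane:(oi+1)*plane], src[(oi*2)*plane:(oi*2+1)*plane])
		copy(t1[oi*plane:(oi+1)*plane], src[(oi*2+1)*plane:(oi*2+2)*plane])
	}
	return t0, t1
}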
convert/convert_qwen25_vl.go — new file, 81 lines

@@ -0,0 +1,81 @@
package convert

import (
	"bytes"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
)

type qwen25VLModel struct {
	ModelParameters
	HiddenSize            uint32  `json:"hidden_size"`
	IntermediateSize      uint32  `json:"intermediate_size"`
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	HiddenLayers          uint32  `json:"num_hidden_layers"`
	RopeTheta             float32 `json:"rope_theta"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	RMSNormEPS            float32 `json:"rms_norm_eps"`

	VisionModel struct {
	} `json:"vision_config"`
}

var _ ModelConverter = (*qwen25VLModel)(nil)

func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
	kv := q.ModelParameters.KV(t)
	kv["general.architecture"] = "qwen25vl"
	kv["qwen25vl.block_count"] = q.HiddenLayers
	kv["qwen25vl.context_length"] = q.MaxPositionEmbeddings
	kv["qwen25vl.embedding_length"] = q.HiddenSize
	kv["qwen25vl.feed_forward_length"] = q.IntermediateSize
	kv["qwen25vl.attention.head_count"] = q.NumAttentionHeads
	kv["qwen25vl.attention.head_count_kv"] = q.NumKeyValueHeads
	kv["qwen25vl.rope.freq_base"] = q.RopeTheta
	kv["qwen25vl.attention.layer_norm_rms_epsilon"] = q.RMSNormEPS

	return kv
}

func (q *qwen25VLModel) Tensors(ts []Tensor) []ggml.Tensor {
	var out []ggml.Tensor

	for _, t := range ts {
		if strings.HasSuffix(t.Name(), "patch_embed.proj.weight") {
			var buf bytes.Buffer
			t.WriteTo(&buf)
			newTensors := splitPatchEmbed(buf, t.Kind(), t.Shape())
			out = append(out, newTensors...)
		} else {
			out = append(out, ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    t.Shape(),
				WriterTo: t,
			})
		}
	}

	return out
}

func (p *qwen25VLModel) Replacements() []string {
	return []string{
		"lm_head", "output",
		"model.embed_tokens", "token_embd",
		"model.layers", "blk",
		"visual.blocks", "v.blk",
		"input_layernorm", "attn_norm",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.q_proj", "attn_q",
		"self_attn.o_proj", "attn_output",
		"mlp.down_proj", "ffn_down",
		"mlp.gate_proj", "ffn_gate",
		"mlp.up_proj", "ffn_up",
		"post_attention_layernorm", "ffn_norm",
		"model.norm", "output_norm",
	}
}
@@ -462,7 +462,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
 	panic("not implemented")
 }

-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
 	panic("not implemented")
 }
@@ -118,6 +118,53 @@ type Context interface {
 	Layer(int) Context
 }

+// RopeType represents different RoPE (Rotary Position Embedding) implementation types
+type RopeType int
+
+// Available RoPE implementation types
+const (
+	RopeTypeNormal RopeType = iota // Standard RoPE implementation
+	RopeTypeNeox                   // NeoX-style RoPE implementation
+	RopeTypeMRoPE                  // Multi-scale RoPE implementation
+	RopeTypeVision                 // Vision-specific RoPE implementation
+)
+
+type YarnConfig struct {
+	YarnCtxTrain   int     // Context size used during training (for YaRN scaling)
+	YarnExtFactor  float32 // Extension factor for YaRN
+	YarnAttnFactor float32 // Attention scaling factor for YaRN
+	YarnBetaFast   float32 // Fast decay parameter for YaRN
+	YarnBetaSlow   float32 // Slow decay parameter for YaRN
+}
+
+// DefaultYarnConfig returns a default configuration for YaRN (Yet another RoPE extensioN)
+func DefaultYarnConfig(nCtx int32) *YarnConfig {
+	return &YarnConfig{
+		YarnCtxTrain:   int(nCtx),
+		YarnExtFactor:  0.0,
+		YarnAttnFactor: 1.0,
+		YarnBetaFast:   32.0,
+		YarnBetaSlow:   1.0,
+	}
+}
+
+// RoPEConfig holds configuration for Rotary Position Embedding
+type RoPEConfig struct {
+	// Dim is the dimensionality for applying rotary embeddings
+	Dim uint32
+
+	// Type specifies the RoPE implementation variant
+	Type RopeType
+
+	// Base controls frequency decay for the embeddings
+	Base float32
+
+	// Scale allows scaling the effective context length
+	Scale float32
+
+	*YarnConfig
+}
+
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int

@@ -141,7 +188,7 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor

 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
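Aside (not code from this branch): with the interface change above, a call site bundles all rotary-embedding parameters into a single ml.RoPEConfig instead of passing them positionally. A hedged sketch, assuming a query tensor q, position IDs positions, and a ctx are already in scope, with placeholder field values that real code reads from model metadata:

// Illustrative only; the concrete Dim/Type/Base/Scale values are assumptions.
cfg := ml.RoPEConfig{
	Dim:        128,                          // rotary dimension for this layer
	Type:       ml.RopeTypeNeox,              // variant expected by the checkpoint
	Base:       10000.0,                      // rope.freq_base
	Scale:      1.0,                          // rope.freq_scale
	YarnConfig: ml.DefaultYarnConfig(131072), // same defaults the backend falls back to
}
q = q.RoPE(ctx, positions, nil, cfg)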
@@ -907,6 +907,8 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }

+// GGML RoPE types
+// These are the types used in the C implementation of RoPE
 const (
 	ropeTypeNorm C.int = 0
 	ropeTypeNeox C.int = 2
@@ -914,7 +916,8 @@
 	ropeTypeVision C.int = 24
 )

-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
+// RoPE applies Rotary Position Embeddings to the tensor
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -924,19 +927,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
 	}

+	if config.YarnConfig == nil {
+		config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
+	}
+
+	// Map Go RopeType to C implementation constants
+	var ropeTypeC C.int
+	switch config.Type {
+	case ml.RopeTypeNormal:
+		ropeTypeC = ropeTypeNorm
+	case ml.RopeTypeNeox:
+		ropeTypeC = ropeTypeNeox
+	case ml.RopeTypeMRoPE:
+		ropeTypeC = ropeTypeMrope
+	case ml.RopeTypeVision:
+		ropeTypeC = ropeTypeVision
+	default:
+		ropeTypeC = ropeTypeNorm
+	}
+
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_rope_ext(
-			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-			C.int(ropeDim),
-			C.int(ropeType),
-			131072, // YaRN n_ctx_train
-			C.float(ropeBase),
-			C.float(ropeScale),
-			0., // YaRN ext_factor
-			1., // YaRN attn_factor
-			32., // YaRN beta_fast
-			1., // YaRN beta_slow
+			ctx.(*Context).ctx,
+			dequant,
+			positionIDs.(*Tensor).t,
+			ropeFactors.(*Tensor).t,
+			C.int(config.Dim),
+			ropeTypeC,
+			C.int(config.YarnCtxTrain),
+			C.float(config.Base),
+			C.float(config.Scale),
+			C.float(config.YarnExtFactor),
+			C.float(config.YarnAttnFactor),
+			C.float(config.YarnBetaFast),
+			C.float(config.YarnBetaSlow),
 		),
 	}
 }
@@ -13,10 +13,11 @@ import (
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps, ropeBase, ropeScale         float32
+	eps                              float32
 	attnLogitSoftcap                 float32
 	finalLogitSoftcap                float32
 	largeModelScaling                bool
+	ropeConfig                       ml.RoPEConfig
 }

 type Model struct {

@@ -55,10 +56,15 @@ func New(c ml.Config) (model.Model, error) {
 			attnKeyLen:        int(c.Uint("attention.key_length")),
 			attnValLen:        int(c.Uint("attention.value_length")),
 			eps:               c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:          c.Float("rope.freq_base", 10000.0),
-			ropeScale:         c.Float("rope.freq_scale", 1.0),
 			attnLogitSoftcap:  c.Float("attn_logit_softcapping"),
 			finalLogitSoftcap: c.Float("final_logit_softcapping"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base", 10000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length"),
+				Type:       ml.RopeTypeNormal,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}

@@ -78,11 +84,10 @@ type SelfAttention struct {

 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)

 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)

 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))

@@ -92,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)

 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

@@ -122,7 +127,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }

 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
 }

 type MLP struct {
@@ -13,9 +13,11 @@ import (
 type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps, ropeScale                   float32
-	ropeLocalBase, ropeGlobalBase    float32
+	eps                              float32
 	largeModelScaling                bool
+
+	ropeLocalConfig  ml.RoPEConfig
+	ropeGlobalConfig ml.RoPEConfig
 }

 type TextModel struct {

@@ -56,15 +58,27 @@ func newTextModel(c ml.Config) *TextModel {
 		),
 		Layers: make([]TextLayer, numBlocks),
 		TextOptions: &TextOptions{
-			hiddenSize:     int(c.Uint("embedding_length")),
-			numHeads:       int(c.Uint("attention.head_count")),
-			numKVHeads:     int(c.Uint("attention.head_count_kv")),
-			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
-			attnValLen:     int(c.Uint("attention.value_length", 256)),
-			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
-			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:      c.Float("rope.freq_scale", 1.0),
+			hiddenSize: int(c.Uint("embedding_length")),
+			numHeads:   int(c.Uint("attention.head_count")),
+			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			attnKeyLen: int(c.Uint("attention.key_length", 256)),
+			attnValLen: int(c.Uint("attention.value_length", 256)),
+			eps:        c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+
+			ropeLocalConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.local.freq_base", 10000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length", 256),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
+			ropeGlobalConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.global.freq_base", 1000000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length", 256),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}

@@ -86,17 +100,16 @@ type TextSelfAttention struct {

 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)

-	ropeBase := opts.ropeLocalBase
+	ropeConfig := opts.ropeLocalConfig
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = opts.ropeGlobalBase
+		ropeConfig = opts.ropeGlobalConfig
 	}

 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, nil, ropeConfig)

 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))

@@ -107,7 +120,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, nil, ropeConfig)

 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

@@ -120,12 +133,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }

 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeBase := m.TextOptions.ropeLocalBase
+	ropeConfig := m.ropeLocalConfig
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = m.TextOptions.ropeGlobalBase
+		ropeConfig = m.ropeGlobalConfig
 	}

-	return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, ropeConfig), nil
 }

 type TextMLP struct {
@@ -14,8 +14,8 @@ import (
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	eps                              float32
+	ropeConfig                       ml.RoPEConfig
 }

 type Model struct {

@@ -54,9 +54,13 @@ func New(c ml.Config) (model.Model, error) {
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.freq_scale", 1),
-			ropeDim:    c.Uint("rope.dimension_count"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base"),
+				Scale:      c.Float("rope.freq_scale", 1),
+				Dim:        c.Uint("rope.dimension_count"),
+				Type:       ml.RopeTypeNormal,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}

@@ -76,15 +80,14 @@ type SelfAttention struct
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	ropeType := uint32(0)

 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)

 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)

 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

@@ -97,7 +100,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }

 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
 }

 type MLP struct {
@@ -20,15 +20,14 @@ type TextSelfAttention struct
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	ropeType := uint32(0)

 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)

 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)

 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

@@ -43,7 +42,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
 	}

 	return key, nil

@@ -198,8 +197,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,

 type TextModelOptions struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	eps                              float32
+	ropeConfig                       ml.RoPEConfig

 	crossAttentionLayers []uint32
 }

@@ -240,10 +239,14 @@ func newTextModel(c ml.Config) *TextModel {
 			numHeads:             int(c.Uint("attention.head_count")),
 			numKVHeads:           int(c.Uint("attention.head_count_kv")),
 			eps:                  c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:             c.Float("rope.freq_base"),
-			ropeScale:            c.Float("rope.freq_scale", 1),
-			ropeDim:              c.Uint("rope.dimension_count"),
 			crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base"),
+				Scale:      c.Float("rope.freq_scale", 1),
+				Dim:        c.Uint("rope.dimension_count"),
+				Type:       ml.RopeTypeNormal,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}
 }