Compare commits
16 Commits
main
...
jmorganca/
Author | SHA1 | Date | |
---|---|---|---|
![]() |
cfeca27133 | ||
![]() |
4530661799 | ||
![]() |
8dd2a81f8c | ||
![]() |
caddb1e4cf | ||
![]() |
4d8dac8ffc | ||
![]() |
63e6509ec0 | ||
![]() |
6f34126dcc | ||
![]() |
ecc0ef468f | ||
![]() |
9b57238834 | ||
![]() |
3b4ad00a4b | ||
![]() |
9a12fd1067 | ||
![]() |
edac05387f | ||
![]() |
e65cf9dc94 | ||
![]() |
7e3c62f388 | ||
![]() |
a75703b2cc | ||
![]() |
c24e8860c1 |
@ -182,8 +182,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
|
|
||||||
var conv ModelConverter
|
var conv ModelConverter
|
||||||
switch p.Architectures[0] {
|
switch p.Architectures[0] {
|
||||||
case "LlamaForCausalLM", "MistralForCausalLM":
|
case "LlamaForCausalLM":
|
||||||
conv = &llamaModel{}
|
conv = &llamaModel{}
|
||||||
|
case "Mistral3ForConditionalGeneration":
|
||||||
|
conv = &mistral3Model{}
|
||||||
case "MixtralForCausalLM":
|
case "MixtralForCausalLM":
|
||||||
conv = &mixtralModel{}
|
conv = &mixtralModel{}
|
||||||
case "GemmaForCausalLM":
|
case "GemmaForCausalLM":
|
||||||
@ -246,5 +248,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// iterate through all ts and print the name
|
||||||
|
for _, t := range ts {
|
||||||
|
fmt.Print(t.Name(), "\n")
|
||||||
|
}
|
||||||
|
|
||||||
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
|
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
|
||||||
}
|
}
|
||||||
|
194
convert/convert_mistral.go
Normal file
194
convert/convert_mistral.go
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pdevine/tensor"
|
||||||
|
"github.com/pdevine/tensor/native"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mistral3Model struct {
|
||||||
|
ModelParameters
|
||||||
|
ImageTokenIndex uint32 `json:"image_token_index"`
|
||||||
|
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||||
|
VisionFeatureLayer int32 `json:"vision_feature_layer"`
|
||||||
|
TextModel struct {
|
||||||
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
|
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
|
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||||
|
HeadDim uint32 `json:"head_dim"`
|
||||||
|
SlidingWindow *uint32 `json:"sliding_window"`
|
||||||
|
HiddenAct string `json:"hidden_act"`
|
||||||
|
VocabSize uint32 `json:"vocab_size"`
|
||||||
|
} `json:"text_config"`
|
||||||
|
VisionModel struct {
|
||||||
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
|
ImageSize uint32 `json:"image_size"`
|
||||||
|
NumChannels uint32 `json:"num_channels"`
|
||||||
|
PatchSize uint32 `json:"patch_size"`
|
||||||
|
HeadDim uint32 `json:"head_dim"`
|
||||||
|
HiddenAct string `json:"hidden_act"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
} `json:"vision_config"`
|
||||||
|
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
|
||||||
|
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||||
|
kv := p.ModelParameters.KV(t)
|
||||||
|
kv["general.architecture"] = "mistral3"
|
||||||
|
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
|
||||||
|
|
||||||
|
// Text configuration
|
||||||
|
kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
|
||||||
|
kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
|
||||||
|
kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
|
||||||
|
kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
|
||||||
|
kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
|
||||||
|
kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
|
||||||
|
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
|
||||||
|
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
|
||||||
|
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
|
||||||
|
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
|
||||||
|
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
|
||||||
|
|
||||||
|
// Vision configuration
|
||||||
|
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
||||||
|
kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
|
||||||
|
kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
|
||||||
|
kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
|
||||||
|
kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
|
||||||
|
kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
|
||||||
|
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
|
||||||
|
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
|
||||||
|
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
|
||||||
|
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
|
||||||
|
|
||||||
|
// Multimodal configuration
|
||||||
|
kv["mistral3.image_token_index"] = p.ImageTokenIndex
|
||||||
|
kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
|
||||||
|
|
||||||
|
kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
|
||||||
|
|
||||||
|
if p.ProjectorHiddenAct != "" {
|
||||||
|
kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
|
||||||
|
var out []ggml.Tensor
|
||||||
|
|
||||||
|
for _, t := range ts {
|
||||||
|
if strings.HasSuffix(t.Name(), "attn_q.weight") ||
|
||||||
|
strings.HasSuffix(t.Name(), "attn_k.weight") {
|
||||||
|
t.SetRepacker(p.repack)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip certain vision model tensors that might need special handling
|
||||||
|
if strings.HasPrefix(t.Name(), "patch_merger.") || strings.HasPrefix(t.Name(), "pre_mm_projector_output_norm.") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
out = append(out, ggml.Tensor{
|
||||||
|
Name: t.Name(),
|
||||||
|
Kind: t.Kind(),
|
||||||
|
Shape: t.Shape(),
|
||||||
|
WriterTo: t,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mistral3Model) Replacements() []string {
|
||||||
|
return []string{
|
||||||
|
"language_model.model.norm", "output_norm",
|
||||||
|
"language_model.model.", "",
|
||||||
|
"language_model.", "",
|
||||||
|
"layers", "blk",
|
||||||
|
"transformer.layers", "blk",
|
||||||
|
"vision_tower", "v",
|
||||||
|
"ln_pre", "encoder_norm",
|
||||||
|
"input_layernorm", "attn_norm",
|
||||||
|
"post_attention_layernorm", "ffn_norm",
|
||||||
|
"embed_tokens", "token_embd",
|
||||||
|
"self_attn.q_proj", "attn_q",
|
||||||
|
"self_attn.k_proj", "attn_k",
|
||||||
|
"self_attn.v_proj", "attn_v",
|
||||||
|
"self_attn.o_proj", "attn_output",
|
||||||
|
"mlp.down_proj", "ffn_down",
|
||||||
|
"mlp.gate_proj", "ffn_gate",
|
||||||
|
"mlp.up_proj", "ffn_up",
|
||||||
|
"attention.q_proj", "attn_q",
|
||||||
|
"attention.k_proj", "attn_k",
|
||||||
|
"attention.v_proj", "attn_v",
|
||||||
|
"attention.o_proj", "attn_output",
|
||||||
|
"attention_norm", "attn_norm",
|
||||||
|
"feed_forward.gate_proj", "ffn_gate",
|
||||||
|
"feed_forward.down_proj", "ffn_down",
|
||||||
|
"feed_forward.up_proj", "ffn_up",
|
||||||
|
"patch_merger.merging_layer", "merger",
|
||||||
|
"multi_modal_projector", "mm",
|
||||||
|
"ffn_norm", "ffn_norm",
|
||||||
|
"lm_head", "output",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
|
var dims []int
|
||||||
|
for _, dim := range shape {
|
||||||
|
dims = append(dims, int(dim))
|
||||||
|
}
|
||||||
|
|
||||||
|
var heads uint32
|
||||||
|
if strings.HasSuffix(name, "attn_q.weight") {
|
||||||
|
heads = p.TextModel.NumAttentionHeads
|
||||||
|
} else if strings.HasSuffix(name, "attn_k.weight") {
|
||||||
|
heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||||
|
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := n.T(0, 2, 1, 3); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := n.Reshape(dims...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := n.Transpose(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, err := native.SelectF32(n, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var f32s []float32
|
||||||
|
for _, t := range ts {
|
||||||
|
f32s = append(f32s, t...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f32s, nil
|
||||||
|
}
|
@ -62,10 +62,7 @@ func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
|
|||||||
Pattern string
|
Pattern string
|
||||||
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
|
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
|
||||||
}{
|
}{
|
||||||
{"model-*-of-*.safetensors", parseSafetensors},
|
{"*.safetensors", parseSafetensors},
|
||||||
{"model.safetensors", parseSafetensors},
|
|
||||||
{"adapters.safetensors", parseSafetensors},
|
|
||||||
{"adapter_model.safetensors", parseSafetensors},
|
|
||||||
{"pytorch_model-*-of-*.bin", parseTorch},
|
{"pytorch_model-*-of-*.bin", parseTorch},
|
||||||
{"pytorch_model.bin", parseTorch},
|
{"pytorch_model.bin", parseTorch},
|
||||||
{"consolidated.*.pth", parseTorch},
|
{"consolidated.*.pth", parseTorch},
|
||||||
|
@ -144,6 +144,9 @@ type Tensor interface {
|
|||||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||||
|
|
||||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
|
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
|
||||||
|
RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, ropeDim uint32, sections [4]int, ropeType uint32, base, scale float32) Tensor
|
||||||
|
|
||||||
|
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||||
|
|
||||||
Tanh(ctx Context) Tensor
|
Tanh(ctx Context) Tensor
|
||||||
GELU(ctx Context) Tensor
|
GELU(ctx Context) Tensor
|
||||||
|
@ -958,6 +958,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
|
||||||
|
if ropeFactors == nil {
|
||||||
|
ropeFactors = &Tensor{b: t.b}
|
||||||
|
}
|
||||||
|
|
||||||
|
dequant := t.t
|
||||||
|
if C.ggml_is_quantized(t.t._type) {
|
||||||
|
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_rope_multi(
|
||||||
|
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
|
||||||
|
C.int(ropeDim),
|
||||||
|
(*C.int)(unsafe.Pointer(§ions[0])),
|
||||||
|
C.int(ropeType),
|
||||||
|
131072, // YaRN n_ctx_train
|
||||||
|
C.float(ropeBase),
|
||||||
|
C.float(ropeScale),
|
||||||
|
0., // YaRN ext_factor
|
||||||
|
1., // YaRN attn_factor
|
||||||
|
32., // YaRN beta_fast
|
||||||
|
1., // YaRN beta_slow
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||||
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_im2col(ctx.(*Context).ctx, t.t, weight.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
b: t.b,
|
b: t.b,
|
||||||
|
@ -2186,6 +2186,10 @@ static void ggml_metal_encode_node(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
{
|
{
|
||||||
|
if (ne00 != ne10) {
|
||||||
|
printf("mul_mat, ne00: %d, ne01: %d, ne02: %d, ne03: %d, ne10: %d, ne11: %d, ne12: %d, ne13: %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13);
|
||||||
|
}
|
||||||
|
|
||||||
GGML_ASSERT(ne00 == ne10);
|
GGML_ASSERT(ne00 == ne10);
|
||||||
|
|
||||||
GGML_ASSERT(ne12 % ne02 == 0);
|
GGML_ASSERT(ne12 % ne02 == 0);
|
||||||
|
@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TextOptions struct {
|
type TextConfig struct {
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads int
|
||||||
attnKeyLen, attnValLen int
|
attnKeyLen, attnValLen int
|
||||||
eps, ropeScale float32
|
eps, ropeScale float32
|
||||||
@ -27,7 +27,7 @@ type TextModel struct {
|
|||||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||||
|
|
||||||
*TextOptions
|
*TextConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -55,7 +55,7 @@ func newTextModel(c ml.Config) *TextModel {
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
Layers: make([]TextLayer, numBlocks),
|
Layers: make([]TextLayer, numBlocks),
|
||||||
TextOptions: &TextOptions{
|
TextConfig: &TextConfig{
|
||||||
hiddenSize: int(c.Uint("embedding_length")),
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
numHeads: int(c.Uint("attention.head_count")),
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
@ -84,7 +84,7 @@ type TextSelfAttention struct {
|
|||||||
Output *nn.Linear `gguf:"attn_output"`
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
|
||||||
batchSize := hiddenState.Dim(1)
|
batchSize := hiddenState.Dim(1)
|
||||||
ropeType := uint32(2)
|
ropeType := uint32(2)
|
||||||
|
|
||||||
@ -120,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
ropeBase := m.TextOptions.ropeLocalBase
|
ropeBase := m.TextConfig.ropeLocalBase
|
||||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||||
ropeBase = m.TextOptions.ropeGlobalBase
|
ropeBase = m.TextConfig.ropeGlobalBase
|
||||||
}
|
}
|
||||||
|
|
||||||
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
|
return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextMLP struct {
|
type TextMLP struct {
|
||||||
@ -134,7 +134,7 @@ type TextMLP struct {
|
|||||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
|
||||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||||
return mlp.Down.Forward(ctx, hiddenState)
|
return mlp.Down.Forward(ctx, hiddenState)
|
||||||
}
|
}
|
||||||
@ -148,7 +148,7 @@ type TextLayer struct {
|
|||||||
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
|
||||||
residual := hiddenState
|
residual := hiddenState
|
||||||
|
|
||||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||||
@ -173,7 +173,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
|||||||
|
|
||||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||||
|
|
||||||
// set image embeddings
|
// set image embeddings
|
||||||
var except []int
|
var except []int
|
||||||
@ -206,7 +206,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
|||||||
lastLayerOutputs = outputs
|
lastLayerOutputs = outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
|
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||||
|
@ -13,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads, headDim int
|
||||||
eps, ropeBase, ropeScale float32
|
eps, ropeBase, ropeScale float32
|
||||||
ropeDim uint32
|
ropeDim uint32
|
||||||
}
|
}
|
||||||
@ -37,6 +37,8 @@ func New(c ml.Config) (model.Model, error) {
|
|||||||
|
|
||||||
m := Model{
|
m := Model{
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
|
// TODO: need to set this in the conversion for mistral:
|
||||||
|
// tokenizer.ggml.pretokenizer = [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
@ -53,6 +55,7 @@ func New(c ml.Config) (model.Model, error) {
|
|||||||
hiddenSize: int(c.Uint("embedding_length")),
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
numHeads: int(c.Uint("attention.head_count")),
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
|
headDim: int(c.Uint("attention.key_length")),
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
ropeBase: c.Float("rope.freq_base"),
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
ropeScale: c.Float("rope.freq_scale", 1),
|
ropeScale: c.Float("rope.freq_scale", 1),
|
||||||
@ -75,24 +78,36 @@ type SelfAttention struct {
|
|||||||
|
|
||||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
batchSize := hiddenState.Dim(1)
|
batchSize := hiddenState.Dim(1)
|
||||||
headDim := opts.hiddenSize / opts.numHeads
|
|
||||||
ropeType := uint32(0)
|
ropeType := uint32(0)
|
||||||
|
// Get head dimension - use explicit value if available, otherwise calculate
|
||||||
|
headDim := opts.headDim
|
||||||
|
if headDim == 0 {
|
||||||
|
headDim = opts.hiddenSize / opts.numHeads
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query projection and reshape
|
||||||
q := sa.Query.Forward(ctx, hiddenState)
|
q := sa.Query.Forward(ctx, hiddenState)
|
||||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||||
|
|
||||||
|
// Key projection and reshape
|
||||||
k := sa.Key.Forward(ctx, hiddenState)
|
k := sa.Key.Forward(ctx, hiddenState)
|
||||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||||
|
|
||||||
|
// Value projection and reshape
|
||||||
v := sa.Value.Forward(ctx, hiddenState)
|
v := sa.Value.Forward(ctx, hiddenState)
|
||||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
|
// Attention computation
|
||||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
|
||||||
|
|
||||||
|
// Reshape attention output for final projection
|
||||||
|
outputDim := headDim * opts.numHeads
|
||||||
|
kqv = kqv.Reshape(ctx, outputDim, batchSize)
|
||||||
|
|
||||||
|
// Apply output projection
|
||||||
return sa.Output.Forward(ctx, kqv)
|
return sa.Output.Forward(ctx, kqv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package pixtral
|
package mistral3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model/imageproc"
|
"github.com/ollama/ollama/model/imageproc"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,15 +21,14 @@ func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
|
|||||||
|
|
||||||
func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
|
func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
|
||||||
b := img.Bounds()
|
b := img.Bounds()
|
||||||
le := float64(longestEdge)
|
ratio := math.Max(float64(b.Max.Y)/float64(longestEdge), float64(b.Max.X)/float64(longestEdge))
|
||||||
ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le)
|
|
||||||
|
|
||||||
newSize := img.Bounds().Max
|
newSize := img.Bounds().Max
|
||||||
|
|
||||||
if ratio > 1.0 {
|
if ratio > 1.0 {
|
||||||
newSize = image.Point{
|
newSize = image.Point{
|
||||||
int(math.Ceil(float64(b.Max.X) / ratio)),
|
int(math.Floor(float64(b.Max.X) / ratio)),
|
||||||
int(math.Ceil(float64(b.Max.Y) / ratio)),
|
int(math.Floor(float64(b.Max.Y) / ratio)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,3 +66,27 @@ func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
|
|||||||
opts := map[string]any{}
|
opts := map[string]any{}
|
||||||
return data, opts, nil
|
return data, opts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ImageProcessor struct {
|
||||||
|
imageSize int
|
||||||
|
patchSize int
|
||||||
|
numChannels int
|
||||||
|
longestEdge int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newImageProcessor(c ml.Config) ImageProcessor {
|
||||||
|
return ImageProcessor{
|
||||||
|
imageSize: int(c.Uint("vision.image_size", 1540)),
|
||||||
|
patchSize: int(c.Uint("vision.patch_size", 14)),
|
||||||
|
numChannels: int(c.Uint("vision.num_channels", 3)),
|
||||||
|
longestEdge: int(c.Uint("vision.longest_edge", 1540)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
|
||||||
|
outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize})
|
||||||
|
newImage := imageproc.Composite(img)
|
||||||
|
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
|
||||||
|
data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
|
||||||
|
return data, nil
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package pixtral
|
package mistral3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
120
model/models/mistral3/model.go
Normal file
120
model/models/mistral3/model.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package mistral3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"image"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/kvcache"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
model.Base
|
||||||
|
*TextModel
|
||||||
|
*VisionModel `gguf:"v,vision"`
|
||||||
|
*MultiModalProjector `gguf:"mm"`
|
||||||
|
|
||||||
|
ImageProcessor
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement MultimodalProcessor interface
|
||||||
|
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||||
|
|
||||||
|
func New(c ml.Config) (model.Model, error) {
|
||||||
|
textModel, err := NewTextModel(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &Model{
|
||||||
|
TextModel: textModel,
|
||||||
|
VisionModel: newVisionModel(c),
|
||||||
|
ImageProcessor: newImageProcessor(c),
|
||||||
|
MultiModalProjector: newMultiModalProjector(c),
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||||
|
if len(m.VisionModel.Layers) == 0 {
|
||||||
|
return nil, model.ErrNoVisionModel
|
||||||
|
}
|
||||||
|
|
||||||
|
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s, err := m.ImageProcessor.ProcessImage(image)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create tensor from image data
|
||||||
|
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
|
||||||
|
m.ImageProcessor.imageSize,
|
||||||
|
1036, // TODO (jmorganca): this should be returned from ProcessImage
|
||||||
|
m.ImageProcessor.numChannels,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// fmt.Println("pixelValues", "shape", pixelValues.Shape(), "data", ml.Dump(ctx, pixelValues))
|
||||||
|
|
||||||
|
// Forward pass through vision model
|
||||||
|
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||||
|
|
||||||
|
// fmt.Println("visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
||||||
|
|
||||||
|
// Project to text embedding space
|
||||||
|
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
|
||||||
|
|
||||||
|
// fmt.Println("visionOutputs after projector", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
||||||
|
|
||||||
|
return visionOutputs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||||
|
var result []input.Input
|
||||||
|
|
||||||
|
for _, inp := range inputs {
|
||||||
|
if inp.Multimodal == nil {
|
||||||
|
result = append(result, inp)
|
||||||
|
} else {
|
||||||
|
inputMultimodal := inp.Multimodal.(ml.Tensor)
|
||||||
|
|
||||||
|
// Add special image tokens - using the imageTokenIndex from config
|
||||||
|
result = append(result, input.Input{Token: 10}) // [IMG]
|
||||||
|
result = append(result, input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}) // image data
|
||||||
|
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, inputMultimodal.Dim(1)-1)...) // [IMG] placeholders
|
||||||
|
result = append(result, input.Input{Token: 13}) // [IMG_END]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
model.Register("mistral3", New)
|
||||||
|
}
|
171
model/models/mistral3/model_text.go
Normal file
171
model/models/mistral3/model_text.go
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
package mistral3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/kvcache"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TextOptions struct {
|
||||||
|
hiddenSize, numHeads, numKVHeads, headDim int
|
||||||
|
eps, ropeBase, ropeScale float32
|
||||||
|
ropeDim uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type TextModel struct {
|
||||||
|
model.Base
|
||||||
|
model.BytePairEncoding
|
||||||
|
|
||||||
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||||
|
Layers []Layer `gguf:"blk"`
|
||||||
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||||
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||||
|
|
||||||
|
*TextOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
type SelfAttention struct {
|
||||||
|
Query *nn.Linear `gguf:"attn_q"`
|
||||||
|
Key *nn.Linear `gguf:"attn_k"`
|
||||||
|
Value *nn.Linear `gguf:"attn_v"`
|
||||||
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
|
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||||
|
batchSize := hiddenState.Dim(1)
|
||||||
|
ropeType := uint32(0)
|
||||||
|
// Get head dimension - use explicit value if available, otherwise calculate
|
||||||
|
headDim := opts.headDim
|
||||||
|
if headDim == 0 {
|
||||||
|
headDim = opts.hiddenSize / opts.numHeads
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query projection and reshape
|
||||||
|
q := sa.Query.Forward(ctx, hiddenState)
|
||||||
|
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||||
|
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||||
|
|
||||||
|
// Key projection and reshape
|
||||||
|
k := sa.Key.Forward(ctx, hiddenState)
|
||||||
|
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||||
|
|
||||||
|
// Value projection and reshape
|
||||||
|
v := sa.Value.Forward(ctx, hiddenState)
|
||||||
|
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
|
// Attention computation
|
||||||
|
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||||
|
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||||
|
|
||||||
|
// Reshape attention output for final projection
|
||||||
|
outputDim := headDim * opts.numHeads
|
||||||
|
kqv = kqv.Reshape(ctx, outputDim, batchSize)
|
||||||
|
|
||||||
|
// Apply output projection
|
||||||
|
return sa.Output.Forward(ctx, kqv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type MLP struct {
|
||||||
|
Up *nn.Linear `gguf:"ffn_up"`
|
||||||
|
Down *nn.Linear `gguf:"ffn_down"`
|
||||||
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||||
|
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||||
|
return mlp.Down.Forward(ctx, hiddenState)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Layer struct {
|
||||||
|
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||||
|
SelfAttention *SelfAttention
|
||||||
|
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||||
|
MLP *MLP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||||
|
residual := hiddenState
|
||||||
|
|
||||||
|
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||||
|
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
||||||
|
|
||||||
|
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||||
|
// we need logits for.
|
||||||
|
if outputs != nil {
|
||||||
|
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||||
|
residual = residual.Rows(ctx, outputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenState = hiddenState.Add(ctx, residual)
|
||||||
|
residual = hiddenState
|
||||||
|
|
||||||
|
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||||
|
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
||||||
|
return hiddenState.Add(ctx, residual)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||||
|
// Process text inputs
|
||||||
|
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||||
|
|
||||||
|
// Process through text transformer layers
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
cache.SetLayer(i)
|
||||||
|
|
||||||
|
var lastLayerOutputs ml.Tensor
|
||||||
|
if i == len(m.Layers)-1 {
|
||||||
|
lastLayerOutputs = outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||||
|
return m.Output.Forward(ctx, hiddenState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextModel(c ml.Config) (*TextModel, error) {
|
||||||
|
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
|
||||||
|
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
|
||||||
|
}
|
||||||
|
|
||||||
|
textModel := &TextModel{
|
||||||
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
|
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||||
|
&model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Types: c.Uints("tokenizer.ggml.token_type"),
|
||||||
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
|
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
|
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Layers: make([]Layer, c.Uint("block_count")),
|
||||||
|
TextOptions: &TextOptions{
|
||||||
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
|
headDim: int(c.Uint("attention.key_length")),
|
||||||
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
|
ropeScale: c.Float("rope.freq_scale", 1),
|
||||||
|
ropeDim: c.Uint("rope.dimension_count"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return textModel, nil
|
||||||
|
}
|
201
model/models/mistral3/model_vision.go
Normal file
201
model/models/mistral3/model_vision.go
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
package mistral3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
var batchSize int = 1
|
||||||
|
|
||||||
|
type PatchMerger struct {
|
||||||
|
MergingLayer *nn.Linear `gguf:"merging_layer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
||||||
|
// TODO: pass these in
|
||||||
|
w := 110
|
||||||
|
h := 74
|
||||||
|
// tokensPerImage := w * h
|
||||||
|
d := visionOutputs.Dim(0)
|
||||||
|
|
||||||
|
// TODO: handle multiple images, this currently assumes one
|
||||||
|
// fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
||||||
|
|
||||||
|
// Reshape to [h, w, hidden_size]
|
||||||
|
imageGrid := visionOutputs.Reshape(ctx, h, w, d)
|
||||||
|
// fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid))
|
||||||
|
|
||||||
|
// TODO: load from config
|
||||||
|
spatialMergeSize := 2
|
||||||
|
kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1)
|
||||||
|
// fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel))
|
||||||
|
|
||||||
|
patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1)
|
||||||
|
// fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches))
|
||||||
|
|
||||||
|
// fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2))
|
||||||
|
reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2))
|
||||||
|
// fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped))
|
||||||
|
|
||||||
|
return pm.MergingLayer.Forward(ctx, reshaped)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MultiModalProjector struct {
|
||||||
|
Norm *nn.RMSNorm `gguf:"norm"`
|
||||||
|
Linear1 *nn.Linear `gguf:"linear_1"`
|
||||||
|
Linear2 *nn.Linear `gguf:"linear_2"`
|
||||||
|
PatchMerger *PatchMerger `gguf:"patch_merger"`
|
||||||
|
|
||||||
|
spatialMergeSize int
|
||||||
|
imageTokenIndex int
|
||||||
|
hasBias bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
||||||
|
visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
|
||||||
|
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
|
||||||
|
visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
|
||||||
|
visionOutputs = visionOutputs.GELU(ctx)
|
||||||
|
return p.Linear2.Forward(ctx, visionOutputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMultiModalProjector(c ml.Config) *MultiModalProjector {
|
||||||
|
return &MultiModalProjector{
|
||||||
|
spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
|
||||||
|
imageTokenIndex: int(c.Uint("image_token_index", 10)),
|
||||||
|
hasBias: c.Bool("mm.projector_bias", false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type VisionSelfAttention struct {
|
||||||
|
Query *nn.Linear `gguf:"attn_q"`
|
||||||
|
Key *nn.Linear `gguf:"attn_k"`
|
||||||
|
Value *nn.Linear `gguf:"attn_v"`
|
||||||
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||||
|
q := sa.Query.Forward(ctx, hiddenState)
|
||||||
|
k := sa.Key.Forward(ctx, hiddenState)
|
||||||
|
v := sa.Value.Forward(ctx, hiddenState)
|
||||||
|
|
||||||
|
q = q.Reshape(ctx, opts.headDim, opts.numHeads, q.Dim(1), batchSize)
|
||||||
|
k = k.Reshape(ctx, opts.headDim, opts.numHeads, k.Dim(1), batchSize)
|
||||||
|
v = v.Reshape(ctx, opts.headDim, opts.numHeads, v.Dim(1), batchSize)
|
||||||
|
|
||||||
|
ropeType := uint32(24) // 2d vision rope
|
||||||
|
q = q.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
|
||||||
|
k = k.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
|
||||||
|
|
||||||
|
attention := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(opts.headDim)), nil)
|
||||||
|
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||||
|
return sa.Output.Forward(ctx, attention)
|
||||||
|
}
|
||||||
|
|
||||||
|
type VisionMLP struct {
|
||||||
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||||
|
Up *nn.Linear `gguf:"ffn_up"`
|
||||||
|
Down *nn.Linear `gguf:"ffn_down"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||||
|
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||||
|
return mlp.Down.Forward(ctx, hiddenState)
|
||||||
|
}
|
||||||
|
|
||||||
|
type VisionEncoderLayer struct {
|
||||||
|
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||||
|
SelfAttention *VisionSelfAttention
|
||||||
|
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||||
|
MLP *VisionMLP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||||
|
residual := hiddenState
|
||||||
|
|
||||||
|
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||||
|
fmt.Println("after attention norm", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
|
||||||
|
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
|
||||||
|
hiddenState = hiddenState.Add(ctx, residual)
|
||||||
|
residual = hiddenState
|
||||||
|
|
||||||
|
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
|
||||||
|
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
|
||||||
|
return hiddenState.Add(ctx, residual)
|
||||||
|
}
|
||||||
|
|
||||||
|
type VisionModelOptions struct {
|
||||||
|
hiddenSize int
|
||||||
|
numHeads int
|
||||||
|
headDim int
|
||||||
|
intermediateSize int
|
||||||
|
imageSize int
|
||||||
|
patchSize int
|
||||||
|
numChannels int
|
||||||
|
eps float32
|
||||||
|
ropeBase float32
|
||||||
|
ropeScale float32
|
||||||
|
}
|
||||||
|
|
||||||
|
type VisionModel struct {
|
||||||
|
PatchEmbedding *nn.Conv2D `gguf:"patch_conv"`
|
||||||
|
EncoderNorm *nn.RMSNorm `gguf:"encoder_norm"`
|
||||||
|
Layers []VisionEncoderLayer `gguf:"blk"`
|
||||||
|
|
||||||
|
*VisionModelOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||||
|
numPatchesH := pixelValues.Dim(1) / m.patchSize
|
||||||
|
numPatchesW := pixelValues.Dim(0) / m.patchSize
|
||||||
|
numPatches := numPatchesH * numPatchesW
|
||||||
|
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||||
|
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
||||||
|
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||||
|
hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps)
|
||||||
|
|
||||||
|
totalPositions := numPatchesH * numPatchesW
|
||||||
|
positions := make([]int32, totalPositions*4)
|
||||||
|
|
||||||
|
for h := 0; h < numPatchesH; h++ {
|
||||||
|
for w := 0; w < numPatchesW; w++ {
|
||||||
|
index := h*numPatchesW + w
|
||||||
|
positions[totalPositions+index] = int32(h)
|
||||||
|
positions[totalPositions*2+index] = int32(w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, layer := range m.Layers {
|
||||||
|
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fmt.Println("after layers", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
|
||||||
|
|
||||||
|
return hiddenState
|
||||||
|
}
|
||||||
|
|
||||||
|
func newVisionModel(c ml.Config) *VisionModel {
|
||||||
|
return &VisionModel{
|
||||||
|
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
|
||||||
|
VisionModelOptions: &VisionModelOptions{
|
||||||
|
hiddenSize: int(c.Uint("vision.embedding_length", 1024)),
|
||||||
|
numHeads: int(c.Uint("vision.attention.head_count", 16)),
|
||||||
|
headDim: int(c.Uint("vision.attention.key_length", 64)),
|
||||||
|
intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
|
||||||
|
imageSize: int(c.Uint("vision.image_size", 1540)),
|
||||||
|
patchSize: int(c.Uint("vision.patch_size", 14)),
|
||||||
|
numChannels: int(c.Uint("vision.num_channels", 3)),
|
||||||
|
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
|
||||||
|
ropeBase: c.Float("vision.rope.freq_base", 10000.0),
|
||||||
|
ropeScale: c.Float("vision.rope.freq_scale", 1.0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
@ -4,5 +4,6 @@ import (
|
|||||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||||
_ "github.com/ollama/ollama/model/models/llama"
|
_ "github.com/ollama/ollama/model/models/llama"
|
||||||
|
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||||
_ "github.com/ollama/ollama/model/models/mllama"
|
_ "github.com/ollama/ollama/model/models/mllama"
|
||||||
)
|
)
|
||||||
|
@ -263,6 +263,10 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||||
merges[pair.b].runes = nil
|
merges[pair.b].runes = nil
|
||||||
|
|
||||||
|
@ -209,6 +209,322 @@ func TestLlama(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tekken loads the Tekken tokenizer for testing
|
||||||
|
func tekken(t testing.TB) TextProcessor {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Load tokenizer config from mistral-small
|
||||||
|
tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
|
||||||
|
configFile, err := os.Open(tokenizerConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer configFile.Close()
|
||||||
|
|
||||||
|
var config struct {
|
||||||
|
AddBosToken bool `json:"add_bos_token"`
|
||||||
|
AddEosToken bool `json:"add_eos_token"`
|
||||||
|
BosToken string `json:"bos_token"`
|
||||||
|
EosToken string `json:"eos_token"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(configFile).Decode(&config); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load tokenizer.json which contains the vocabulary and other settings
|
||||||
|
tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
|
||||||
|
tokenizerFile, err := os.Open(tokenizerJsonPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tokenizerFile.Close()
|
||||||
|
|
||||||
|
var tokenizerData struct {
|
||||||
|
Model struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Vocab map[string]int32 `json:"vocab"`
|
||||||
|
Merges []string `json:"merges"`
|
||||||
|
} `json:"model"`
|
||||||
|
AddedTokens []struct {
|
||||||
|
Id int32 `json:"id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Special bool `json:"special"`
|
||||||
|
} `json:"added_tokens"`
|
||||||
|
PreTokenizer struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Pretokenizers []struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Pattern struct {
|
||||||
|
String string `json:"String"`
|
||||||
|
} `json:"pattern"`
|
||||||
|
Behavior string `json:"behavior"`
|
||||||
|
} `json:"pretokenizers"`
|
||||||
|
} `json:"pre_tokenizer"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the pattern from pre_tokenizer if available
|
||||||
|
var pattern string
|
||||||
|
if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
|
||||||
|
pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine regular vocab and added tokens
|
||||||
|
vocab := tokenizerData.Model.Vocab
|
||||||
|
|
||||||
|
// Add special tokens from added_tokens
|
||||||
|
for _, token := range tokenizerData.AddedTokens {
|
||||||
|
vocab[token.Content] = token.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create vocabulary arrays
|
||||||
|
maxId := int32(-1)
|
||||||
|
for _, id := range vocab {
|
||||||
|
if id > maxId {
|
||||||
|
maxId = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vocabSize := int(maxId + 1)
|
||||||
|
types := make([]uint32, vocabSize)
|
||||||
|
tokens := make([]string, vocabSize)
|
||||||
|
scores := make([]float32, vocabSize)
|
||||||
|
|
||||||
|
for token, id := range vocab {
|
||||||
|
tokens[id] = token
|
||||||
|
types[id] = TOKEN_TYPE_NORMAL
|
||||||
|
|
||||||
|
// Assign appropriate token types for special tokens
|
||||||
|
if token == "<s>" {
|
||||||
|
types[id] = TOKEN_TYPE_CONTROL
|
||||||
|
} else if token == "</s>" {
|
||||||
|
types[id] = TOKEN_TYPE_CONTROL
|
||||||
|
} else if token == "[INST]" || token == "[/INST]" {
|
||||||
|
types[id] = TOKEN_TYPE_CONTROL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// In Tekken, we don't need to load merges separately as they're part of the model
|
||||||
|
var merges []string
|
||||||
|
|
||||||
|
// Create vocabulary object
|
||||||
|
vocabObj := &Vocabulary{
|
||||||
|
Values: tokens,
|
||||||
|
Types: types,
|
||||||
|
Scores: scores,
|
||||||
|
Merges: merges,
|
||||||
|
BOS: vocab[config.BosToken],
|
||||||
|
EOS: vocab[config.EosToken],
|
||||||
|
AddBOS: config.AddBosToken,
|
||||||
|
AddEOS: config.AddEosToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use pattern from tokenizer.json if available
|
||||||
|
if pattern != "" {
|
||||||
|
// Ensure pattern has proper escaping for Go regexp
|
||||||
|
pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
|
||||||
|
return NewBytePairEncoding(pattern, vocabObj)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback pattern if not found
|
||||||
|
return NewBytePairEncoding(
|
||||||
|
`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
|
||||||
|
vocabObj,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTekken(t *testing.T) {
|
||||||
|
// Skip if the test data isn't available
|
||||||
|
if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
|
||||||
|
t.Skip("Mistral-small test data not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenizer := tekken(t)
|
||||||
|
|
||||||
|
t.Run("whitespace_handling", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
|
||||||
|
cases := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{" hello", " hello"},
|
||||||
|
{"hello ", "hello "},
|
||||||
|
{"hello world", "hello world"},
|
||||||
|
{" hello world ", " hello world "},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
ids, err := tokenizer.Encode(tc.input, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to encode %q: %v", tc.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := tokenizer.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded != tc.expected {
|
||||||
|
t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("chat_templates", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Test the Tekken chat template format which doesn't have spaces after special tokens
|
||||||
|
templates := []struct {
|
||||||
|
input string
|
||||||
|
expectSpace bool // whether we expect a space after special tokens
|
||||||
|
}{
|
||||||
|
{"<s>[INST]user message[/INST]", false},
|
||||||
|
{"<s>[INST] user message[/INST]", true},
|
||||||
|
{"<s>[INST]user message [/INST]", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range templates {
|
||||||
|
ids, err := tokenizer.Encode(tc.input, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to encode %q: %v", tc.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := tokenizer.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if there's a space after special tokens
|
||||||
|
hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
|
||||||
|
|
||||||
|
if hasSpaceAfterINST != tc.expectSpace {
|
||||||
|
t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
|
||||||
|
hasSpaceAfterINST, tc.expectSpace, tc.input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("special_tokens", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Test how Tekken handles special tokens
|
||||||
|
cases := []struct {
|
||||||
|
input string
|
||||||
|
expected []string // We'll check if these tokens are in the decoded output
|
||||||
|
}{
|
||||||
|
{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
|
||||||
|
{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
|
||||||
|
{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
ids, err := tokenizer.Encode(tc.input, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to encode %q: %v", tc.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := tokenizer.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range tc.expected {
|
||||||
|
if !strings.Contains(decoded, expected) {
|
||||||
|
t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vocabulary_coverage", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Tekken has a larger vocabulary, so test coverage of various token types
|
||||||
|
samples := []string{
|
||||||
|
"Hello world!",
|
||||||
|
"This is a test of the Tekken tokenizer.",
|
||||||
|
"It has a considerably larger vocabulary size.",
|
||||||
|
"Special characters: !@#$%^&*()",
|
||||||
|
"Numbers: 1234567890",
|
||||||
|
"Multiple languages: こんにちは 你好 안녕하세요",
|
||||||
|
"Code snippets: def function(): return True",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sample := range samples {
|
||||||
|
ids, err := tokenizer.Encode(sample, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to encode %q: %v", sample, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := tokenizer.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to decode tokens for %q: %v", sample, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded != sample {
|
||||||
|
t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("splitting_behavior", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Test the splitting behavior which might differ from SentencePiece
|
||||||
|
cases := map[string][]string{
|
||||||
|
"Hello World!": {"Hello", " World", "!"},
|
||||||
|
"user message": {"user", " message"},
|
||||||
|
"[INST]hello": {"[INST]", "hello"},
|
||||||
|
"hello[/INST]": {"hello", "[/INST]"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for s, want := range cases {
|
||||||
|
got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
|
t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("full_chat_sequence", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Test a complete chat sequence with Tekken's format
|
||||||
|
chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"
|
||||||
|
|
||||||
|
ids, err := tokenizer.Encode(chatSequence, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to encode chat sequence: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := tokenizer.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to decode chat sequence tokens: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// In Tekken, the whitespace shouldn't be added after special tokens
|
||||||
|
if strings.Contains(decoded, "[INST] ") {
|
||||||
|
t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(decoded, "[/INST] ") {
|
||||||
|
t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||||
tokenizer := llama(b)
|
tokenizer := llama(b)
|
||||||
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
||||||
|
1217945
model/testdata/mistral-small/tokenizer.json
vendored
Normal file
1217945
model/testdata/mistral-small/tokenizer.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9020
model/testdata/mistral-small/tokenizer_config.json
vendored
Normal file
9020
model/testdata/mistral-small/tokenizer_config.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@ -211,16 +211,10 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var files []string
|
var files []string
|
||||||
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
||||||
// safetensors files might be unresolved git lfs references; skip if they are
|
// safetensors files might be unresolved git lfs references; skip if they are
|
||||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||||
files = append(files, st...)
|
files = append(files, st...)
|
||||||
} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
|
|
||||||
// covers adapters.safetensors
|
|
||||||
files = append(files, st...)
|
|
||||||
} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
|
|
||||||
// covers adapter_model.safetensors
|
|
||||||
files = append(files, st...)
|
|
||||||
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
||||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||||
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
||||||
|
@ -182,6 +182,10 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, t := range tokens {
|
||||||
|
decoded, _ := s.model.(model.TextProcessor).Decode([]int32{t})
|
||||||
|
fmt.Println("token", t, "decoded", decoded)
|
||||||
|
}
|
||||||
for _, t := range tokens {
|
for _, t := range tokens {
|
||||||
inputs = append(inputs, input.Input{Token: t})
|
inputs = append(inputs, input.Input{Token: t})
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user