convert: mistral-3.1-2503 text component

This commit is contained in:
Bruce MacDonald 2025-03-20 10:58:23 -07:00
parent 434f793075
commit fe796cfc75
3 changed files with 64 additions and 25 deletions

View File

@ -184,7 +184,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
switch p.Architectures[0] { switch p.Architectures[0] {
case "LlamaForCausalLM": case "LlamaForCausalLM":
conv = &llamaModel{} conv = &llamaModel{}
case "MistralForCausalLM": case "Mistral3ForConditionalGeneration":
conv = &mistralModel{} conv = &mistralModel{}
case "MixtralForCausalLM": case "MixtralForCausalLM":
conv = &mixtralModel{} conv = &mixtralModel{}

View File

@ -13,15 +13,17 @@ import (
type mistralModel struct { type mistralModel struct {
ModelParameters ModelParameters
NumHiddenLayers uint32 `json:"num_hidden_layers"` TextModel struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"` MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
IntermediateSize uint32 `json:"intermediate_size"` HiddenSize uint32 `json:"hidden_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"` IntermediateSize uint32 `json:"intermediate_size"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"` NumAttentionHeads uint32 `json:"num_attention_heads"`
RopeTheta float32 `json:"rope_theta"` NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"` RopeTheta float32 `json:"rope_theta"`
HeadDim uint32 `json:"head_dim"` RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
} `json:"text_config"`
} }
func (p *mistralModel) KV(t *Tokenizer) ggml.KV { func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
@ -29,17 +31,17 @@ func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
kv["general.architecture"] = "mistral" kv["general.architecture"] = "mistral"
kv["mistral.vocab_size"] = p.VocabSize kv["mistral.vocab_size"] = p.VocabSize
kv["mistral.block_count"] = p.NumHiddenLayers kv["mistral.block_count"] = p.TextModel.NumHiddenLayers
kv["mistral.context_length"] = p.MaxPositionEmbeddings kv["mistral.context_length"] = p.TextModel.MaxPositionEmbeddings
kv["mistral.embedding_length"] = cmp.Or(p.HiddenSize) kv["mistral.embedding_length"] = p.TextModel.HiddenSize
kv["mistral.feed_forward_length"] = cmp.Or(p.IntermediateSize) kv["mistral.feed_forward_length"] = p.TextModel.IntermediateSize
kv["mistral.attention.head_count"] = cmp.Or(p.NumAttentionHeads) kv["mistral.attention.head_count"] = p.TextModel.NumAttentionHeads
kv["mistral.rope.dimension_count"] = p.HiddenSize / p.NumHiddenLayers kv["mistral.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral.rope.freq_base"] = p.RopeTheta kv["mistral.rope.freq_base"] = p.TextModel.RopeTheta
kv["mistral.attention.head_count_kv"] = p.NumKeyValueHeads kv["mistral.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
kv["mistral.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS kv["mistral.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral.attention.key_length"] = p.HeadDim kv["mistral.attention.key_length"] = p.TextModel.HeadDim
kv["mistral.attention.value_length"] = p.HeadDim kv["mistral.attention.value_length"] = p.TextModel.HeadDim
return kv return kv
} }
@ -86,6 +88,43 @@ func (p *mistralModel) Replacements() []string {
"mlp.down_proj", "ffn_down", "mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate", "mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up", "mlp.up_proj", "ffn_up",
// Language model replacements
"language_model.model.embed_tokens", "token_embd",
"language_model.model.layers", "blk",
"language_model.model.layers.*.input_layernorm", "attn_norm",
"language_model.model.layers.*.self_attn.q_proj", "attn_q",
"language_model.model.layers.*.self_attn.k_proj", "attn_k",
"language_model.model.layers.*.self_attn.v_proj", "attn_v",
"language_model.model.layers.*.self_attn.o_proj", "attn_output",
"language_model.model.layers.*.mlp.gate_proj", "ffn_gate",
"language_model.model.layers.*.mlp.down_proj", "ffn_down",
"language_model.model.layers.*.mlp.up_proj", "ffn_up",
"language_model.model.layers.*.post_attention_layernorm", "ffn_norm",
"language_model.lm_head", "output",
"language_model.model.norm", "output_norm",
// Vision model replacements - map to shorter prefixes
"vision_tower", "v",
"multi_modal_projector", "mm",
// Vision transformer blocks - these should be updated accordingly
"vision_tower.transformer.layers", "v.blk",
"vision_tower.transformer.layers.*.attention_norm", "v.attn_norm",
"vision_tower.transformer.layers.*.attention.q_proj", "v.attn_q",
"vision_tower.transformer.layers.*.attention.k_proj", "v.attn_k",
"vision_tower.transformer.layers.*.attention.v_proj", "v.attn_v",
"vision_tower.transformer.layers.*.attention.o_proj", "v.attn_output",
"vision_tower.transformer.layers.*.feed_forward.gate_proj", "v.ffn_gate",
"vision_tower.transformer.layers.*.feed_forward.down_proj", "v.ffn_down",
"vision_tower.transformer.layers.*.feed_forward.up_proj", "v.ffn_up",
"vision_tower.transformer.layers.*.ffn_norm", "v.ffn_norm",
"vision_tower.ln_pre", "v.encoder_norm",
"vision_tower.patch_conv", "v.patch_conv",
// Multimodal projector components
"multi_modal_projector.patch_merger", "mm.patch_merger",
"multi_modal_projector.norm", "mm.norm",
} }
} }
@ -97,9 +136,9 @@ func (p *mistralModel) repack(name string, data []float32, shape []uint64) ([]fl
var heads uint32 var heads uint32
if strings.HasSuffix(name, "attn_q.weight") { if strings.HasSuffix(name, "attn_q.weight") {
heads = p.NumAttentionHeads heads = p.TextModel.NumAttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight") { } else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads) heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
} else { } else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name) return nil, fmt.Errorf("unknown tensor for repack: %s", name)
} }

View File

@ -42,9 +42,9 @@ func New(c ml.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"), Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
}, },
), ),