minimal convert

Bruce MacDonald 2025-03-19 16:56:52 -07:00 committed by jmorganca
parent 7e3c62f388
commit e65cf9dc94
3 changed files with 32 additions and 111 deletions

View File

@@ -248,5 +248,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 		return err
 	}
 
+	// iterate through all ts and print the name
+	for _, t := range ts {
+		fmt.Print(t.Name(), "\n")
+	}
+
 	return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
 }
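
Note: the added loop only logs each tensor name as it will be written out. A standalone sketch of the same idea, with a stand-in tensor type and made-up names (real names depend on the model and on the Replacements() list further down):

package main

import "fmt"

// namedTensor is a stand-in for the converter's Tensor interface, which exposes Name().
type namedTensor struct{ name string }

func (t namedTensor) Name() string { return t.name }

func main() {
	// Hypothetical post-replacement names; a real conversion prints one line per tensor.
	ts := []namedTensor{
		{"token_embd.weight"},
		{"blk.0.attn_q.weight"},
		{"blk.0.ffn_up.weight"},
	}
	for _, t := range ts {
		fmt.Print(t.Name(), "\n")
	}
}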

View File

@@ -3,7 +3,6 @@ package convert
 import (
 	"cmp"
 	"fmt"
-	"math"
 	"strings"
 
 	"github.com/pdevine/tensor"
@@ -14,34 +13,15 @@ import (
 type mistralModel struct {
 	ModelParameters
-	NLayers               uint32  `json:"n_layers"`
 	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
-	NLayer                uint32  `json:"n_layer"`
 	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
-	NCtx                  uint32  `json:"n_ctx"`
 	HiddenSize            uint32  `json:"hidden_size"`
-	NEmbd                 uint32  `json:"n_embd"`
 	IntermediateSize      uint32  `json:"intermediate_size"`
-	NInner                uint32  `json:"n_inner"`
 	NumAttentionHeads     uint32  `json:"num_attention_heads"`
-	NHead                 uint32  `json:"n_head"`
 	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
 	RopeTheta             float32 `json:"rope_theta"`
-	RopeScaling           struct {
-		Type                            string  `json:"type"`
-		RopeType                        string  `json:"rope_type"`
-		Factor                          float32 `json:"factor"`
-		LowFrequencyFactor              float32 `json:"low_freq_factor"`
-		HighFrequencyFactor             float32 `json:"high_freq_factor"`
-		OriginalMaxPositionalEmbeddings uint32  `json:"original_max_positional_embeddings"`
-
-		factors ropeFactor
-	} `json:"rope_scaling"`
-	RMSNormEPS       float32 `json:"rms_norm_eps"`
-	LayerNormEPS     float32 `json:"layer_norm_eps"`
-	LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
-	NormEpsilon      float32 `json:"norm_epsilon"`
-	HeadDim          uint32  `json:"head_dim"`
+	RMSNormEPS            float32 `json:"rms_norm_eps"`
+	HeadDim               uint32  `json:"head_dim"`
 }
 
 func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
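
The trimmed struct keeps only the field names that appear in a Hugging Face-style config.json. A minimal, self-contained sketch of how such a config decodes into it (the struct name and values here are illustrative, not the converter's actual types):

package main

import (
	"encoding/json"
	"fmt"
)

// mistralConfig mirrors a few of the json tags kept in the struct above.
type mistralConfig struct {
	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	RMSNormEPS            float32 `json:"rms_norm_eps"`
}

func main() {
	// Made-up config.json excerpt; real values come from the model repository.
	raw := `{"num_hidden_layers": 32, "max_position_embeddings": 32768, "num_key_value_heads": 8, "rms_norm_eps": 1e-05}`

	var cfg mistralConfig
	if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", cfg)
}
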
@@ -49,69 +29,17 @@ func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
 	kv["general.architecture"] = "mistral"
 	kv["mistral.vocab_size"] = p.VocabSize
-	kv["mistral.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
-
-	if contextLength := cmp.Or(p.MaxPositionEmbeddings, p.NCtx); contextLength > 0 {
-		kv["mistral.context_length"] = contextLength
-	}
-
-	if embeddingLength := cmp.Or(p.HiddenSize, p.NEmbd); embeddingLength > 0 {
-		kv["mistral.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd)
-	}
-
-	if feedForwardLength := cmp.Or(p.IntermediateSize, p.NInner); feedForwardLength > 0 {
-		kv["mistral.feed_forward_length"] = cmp.Or(p.IntermediateSize, p.NInner)
-	}
-
-	kv["mistral.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead)
-	kv["mistral.rope.dimension_count"] = p.HiddenSize / cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
-
-	if p.RopeTheta > 0 {
-		kv["mistral.rope.freq_base"] = p.RopeTheta
-	}
-
-	if p.RopeScaling.Type == "linear" {
-		kv["mistral.rope.scaling.type"] = p.RopeScaling.Type
-		kv["mistral.rope.scaling.factor"] = p.RopeScaling.Factor
-	} else if p.RopeScaling.RopeType == "llama3" {
-		dim := p.HiddenSize / p.NumAttentionHeads
-		for i := uint32(0); i < dim; i += 2 {
-			factor := cmp.Or(p.RopeScaling.Factor, 8.0)
-			factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
-			factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)
-
-			original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
-			lambdaLow := float32(original) / factorLow
-			lambdaHigh := float32(original) / factorHigh
-
-			lambda := 2 * math.Pi * math.Pow(float64(p.RopeTheta), float64(i)/float64(dim))
-			if lambda < float64(lambdaHigh) {
-				p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0)
-			} else if lambda > float64(lambdaLow) {
-				p.RopeScaling.factors = append(p.RopeScaling.factors, factor)
-			} else {
-				smooth := (float32(original)/float32(lambda) - factorLow) / (factorHigh - factorLow)
-				p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0/((1-smooth)/factor+smooth))
-			}
-		}
-	}
-
-	if p.NumKeyValueHeads > 0 {
-		kv["mistral.attention.head_count_kv"] = p.NumKeyValueHeads
-	}
-
-	if p.RMSNormEPS > 0 {
-		kv["mistral.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
-	}
-
-	if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon, p.NormEpsilon); layerNormEpsilon > 0 {
-		kv["mistral.attention.layer_norm_epsilon"] = layerNormEpsilon
-	}
-
-	if p.HeadDim > 0 {
-		kv["mistral.attention.key_length"] = p.HeadDim
-		kv["mistral.attention.value_length"] = p.HeadDim
-	}
+	kv["mistral.block_count"] = p.NumHiddenLayers
+	kv["mistral.context_length"] = p.MaxPositionEmbeddings
+	kv["mistral.embedding_length"] = cmp.Or(p.HiddenSize)
+	kv["mistral.feed_forward_length"] = cmp.Or(p.IntermediateSize)
+	kv["mistral.attention.head_count"] = cmp.Or(p.NumAttentionHeads)
+	kv["mistral.rope.dimension_count"] = p.HiddenSize / p.NumHiddenLayers
+	kv["mistral.rope.freq_base"] = p.RopeTheta
+	kv["mistral.attention.head_count_kv"] = p.NumKeyValueHeads
+	kv["mistral.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
+	kv["mistral.attention.key_length"] = p.HeadDim
+	kv["mistral.attention.value_length"] = p.HeadDim
 
 	return kv
 }
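
The deleted branches relied on cmp.Or (Go 1.22+), which returns the first of its arguments that is not the zero value; with the alias fields gone there is nothing left to fall back to, so the single-argument cmp.Or calls kept above are effectively plain assignments. A small standalone sketch of that fallback behavior (variable names reuse the fields removed from the struct):

package main

import (
	"cmp"
	"fmt"
)

func main() {
	// Zero means "not present in config.json".
	var numHiddenLayers uint32 = 32 // num_hidden_layers was set
	var nLayer uint32               // n_layer was not

	fmt.Println(cmp.Or(nLayer, numHiddenLayers)) // 32: first non-zero value wins
	fmt.Println(cmp.Or(numHiddenLayers))         // 32: with one argument it is just that value
}
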
@@ -119,15 +47,6 @@ func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
 func (p *mistralModel) Tensors(ts []Tensor) []ggml.Tensor {
 	var out []ggml.Tensor
 
-	if p.RopeScaling.factors != nil {
-		out = append(out, ggml.Tensor{
-			Name:     "rope_freqs.weight",
-			Kind:     0,
-			Shape:    []uint64{uint64(len(p.RopeScaling.factors))},
-			WriterTo: p.RopeScaling.factors,
-		})
-	}
-
 	for _, t := range ts {
 		if strings.HasSuffix(t.Name(), "attn_q.weight") ||
 			strings.HasSuffix(t.Name(), "attn_k.weight") {
@@ -154,19 +73,19 @@ func (p *mistralModel) Tensors(ts []Tensor) []ggml.Tensor {
 func (p *mistralModel) Replacements() []string {
 	return []string{
-		"tok_embeddings", "token_embd",
-		"norm", "output_norm",
-		"layers", "blk",
-		"attention_norm", "attn_norm",
-		"attention.wq", "attn_q",
-		"attention.wk", "attn_k",
-		"attention.wv", "attn_v",
-		"attention.wo", "attn_output",
-		"feed_forward.w1", "ffn_gate",
-		"feed_forward.w2", "ffn_down",
-		"feed_forward.w3", "ffn_up",
-		"ffn_norm", "ffn_norm",
-		"output", "output",
+		"model.layers", "blk",
+		"input_layernorm", "attn_norm",
+		"post_attention_layernorm", "ffn_norm",
+		"lm_head", "output",
+		"model.embed_tokens.weight", "token_embd.weight",
+		"model.norm.weight", "output_norm.weight",
+		"self_attn.q_proj", "attn_q",
+		"self_attn.k_proj", "attn_k",
+		"self_attn.v_proj", "attn_v",
+		"self_attn.o_proj", "attn_output",
+		"mlp.down_proj", "ffn_down",
+		"mlp.gate_proj", "ffn_gate",
+		"mlp.up_proj", "ffn_up",
 	}
 }
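
Replacements() is a flat list of old/new pairs used to rewrite source tensor names into GGUF names. A reduced sketch of that mapping with strings.NewReplacer, assuming the converter applies the pairs this way (only two of the pairs above are shown):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Two of the pairs from Replacements(); the real replacer is built from the whole list.
	r := strings.NewReplacer(
		"model.layers", "blk",
		"self_attn.q_proj", "attn_q",
	)
	fmt.Println(r.Replace("model.layers.0.self_attn.q_proj.weight")) // blk.0.attn_q.weight
}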

View File

@@ -37,10 +37,7 @@ func New(c ml.Config) (model.Model, error) {
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
-			// TODO: need to set this in the conversion for mistral:
-			// tokenizer.ggml.pretokenizer = [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
-			// c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
+			c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Types:  c.Uints("tokenizer.ggml.token_type"),