wip
This commit is contained in:
parent
c24e8860c1
commit
a75703b2cc
@@ -182,8 +182,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	var conv ModelConverter
 	switch p.Architectures[0] {
-	case "LlamaForCausalLM", "MistralForCausalLM":
+	case "LlamaForCausalLM":
 		conv = &llamaModel{}
+	case "MistralForCausalLM":
+		conv = &mistralModel{}
 	case "MixtralForCausalLM":
 		conv = &mixtralModel{}
 	case "GemmaForCausalLM":
convert/convert_mistral.go · new file · 209 lines
@@ -0,0 +1,209 @@
package convert

import (
	"cmp"
	"fmt"
	"math"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/fs/ggml"
)

type mistralModel struct {
	ModelParameters
	NLayers               uint32  `json:"n_layers"`
	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
	NLayer                uint32  `json:"n_layer"`
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	NCtx                  uint32  `json:"n_ctx"`
	HiddenSize            uint32  `json:"hidden_size"`
	NEmbd                 uint32  `json:"n_embd"`
	IntermediateSize      uint32  `json:"intermediate_size"`
	NInner                uint32  `json:"n_inner"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NHead                 uint32  `json:"n_head"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	RopeTheta             float32 `json:"rope_theta"`
	RopeScaling           struct {
		Type                            string  `json:"type"`
		RopeType                        string  `json:"rope_type"`
		Factor                          float32 `json:"factor"`
		LowFrequencyFactor              float32 `json:"low_freq_factor"`
		HighFrequencyFactor             float32 `json:"high_freq_factor"`
		OriginalMaxPositionalEmbeddings uint32  `json:"original_max_positional_embeddings"`

		factors ropeFactor
	} `json:"rope_scaling"`
	RMSNormEPS       float32 `json:"rms_norm_eps"`
	LayerNormEPS     float32 `json:"layer_norm_eps"`
	LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
	NormEpsilon      float32 `json:"norm_epsilon"`
	HeadDim          uint32  `json:"head_dim"`
}
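Several of these fields are aliases for the same hyperparameter, since checkpoints disagree on key names (`n_layers`, `num_hidden_layers`, `n_layer`). `KV` below collapses each alias group with `cmp.Or`, which returns the first non-zero value. A minimal standalone sketch of that fallback, with hypothetical values:

```go
package main

import (
	"cmp"
	"fmt"
)

func main() {
	// Suppose only num_hidden_layers was present in config.json; the other
	// aliases unmarshal to their zero value.
	var nLayers, nLayer uint32 = 0, 0
	numHiddenLayers := uint32(32)

	// cmp.Or returns the first non-zero argument, so whichever alias the
	// checkpoint actually used wins.
	fmt.Println(cmp.Or(nLayers, numHiddenLayers, nLayer)) // 32
}
```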
func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "mistral"
	kv["mistral.vocab_size"] = p.VocabSize

	kv["mistral.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)

	if contextLength := cmp.Or(p.MaxPositionEmbeddings, p.NCtx); contextLength > 0 {
		kv["mistral.context_length"] = contextLength
	}

	if embeddingLength := cmp.Or(p.HiddenSize, p.NEmbd); embeddingLength > 0 {
		kv["mistral.embedding_length"] = embeddingLength
	}

	if feedForwardLength := cmp.Or(p.IntermediateSize, p.NInner); feedForwardLength > 0 {
		kv["mistral.feed_forward_length"] = feedForwardLength
	}

	if headCount := cmp.Or(p.NumAttentionHeads, p.NHead); headCount > 0 {
		kv["mistral.attention.head_count"] = headCount
		kv["mistral.rope.dimension_count"] = p.HiddenSize / headCount
	}

	if p.RopeTheta > 0 {
		kv["mistral.rope.freq_base"] = p.RopeTheta
	}

	if p.RopeScaling.Type == "linear" {
		kv["mistral.rope.scaling.type"] = p.RopeScaling.Type
		kv["mistral.rope.scaling.factor"] = p.RopeScaling.Factor
	} else if p.RopeScaling.RopeType == "llama3" {
		dim := p.HiddenSize / p.NumAttentionHeads
		for i := uint32(0); i < dim; i += 2 {
			factor := cmp.Or(p.RopeScaling.Factor, 8.0)
			factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
			factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)

			original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
			lambdaLow := float32(original) / factorLow
			lambdaHigh := float32(original) / factorHigh

			lambda := 2 * math.Pi * math.Pow(float64(p.RopeTheta), float64(i)/float64(dim))
			if lambda < float64(lambdaHigh) {
				p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0)
			} else if lambda > float64(lambdaLow) {
				p.RopeScaling.factors = append(p.RopeScaling.factors, factor)
			} else {
				smooth := (float32(original)/float32(lambda) - factorLow) / (factorHigh - factorLow)
				p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0/((1-smooth)/factor+smooth))
			}
		}
	}

	if p.NumKeyValueHeads > 0 {
		kv["mistral.attention.head_count_kv"] = p.NumKeyValueHeads
	}

	if p.RMSNormEPS > 0 {
		kv["mistral.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
	}

	if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon, p.NormEpsilon); layerNormEpsilon > 0 {
		kv["mistral.attention.layer_norm_epsilon"] = layerNormEpsilon
	}

	if p.HeadDim > 0 {
		kv["mistral.attention.key_length"] = p.HeadDim
		kv["mistral.attention.value_length"] = p.HeadDim
	}

	return kv
}
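The `llama3` rope-scaling branch above precomputes one factor per rotary dimension pair: pairs whose wavelength stays below `original/factorHigh` keep factor 1.0, pairs above `original/factorLow` receive the full `factor`, and the band in between is interpolated smoothly. A standalone sketch of the three regimes; `theta` and `dim` are hypothetical here, while the other defaults match the converter's fallbacks:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// Defaults matching the converter's fallbacks above.
	factor, factorLow, factorHigh := 8.0, 1.0, 4.0
	original := 8192.0
	lambdaLow, lambdaHigh := original/factorLow, original/factorHigh // 8192, 2048

	theta, dim := 500000.0, 128.0 // rope_theta and head dim are made up
	for _, i := range []float64{0, 64, 126} {
		lambda := 2 * math.Pi * math.Pow(theta, i/dim) // wavelength of rotary pair i
		switch {
		case lambda < lambdaHigh: // high frequency: leave untouched
			fmt.Printf("i=%3.0f lambda=%10.1f -> factor 1.000\n", i, lambda)
		case lambda > lambdaLow: // low frequency: apply the full scale factor
			fmt.Printf("i=%3.0f lambda=%10.1f -> factor %.3f\n", i, lambda, factor)
		default: // blend smoothly between the two regimes
			smooth := (original/lambda - factorLow) / (factorHigh - factorLow)
			fmt.Printf("i=%3.0f lambda=%10.1f -> factor %.3f\n", i, lambda, 1/((1-smooth)/factor+smooth))
		}
	}
}
```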
func (p *mistralModel) Tensors(ts []Tensor) []ggml.Tensor {
	var out []ggml.Tensor

	if p.RopeScaling.factors != nil {
		out = append(out, ggml.Tensor{
			Name:     "rope_freqs.weight",
			Kind:     0,
			Shape:    []uint64{uint64(len(p.RopeScaling.factors))},
			WriterTo: p.RopeScaling.factors,
		})
	}

	for _, t := range ts {
		if strings.HasSuffix(t.Name(), "attn_q.weight") ||
			strings.HasSuffix(t.Name(), "attn_k.weight") {
			t.SetRepacker(p.repack)
		}

		out = append(out, ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}

	return out
}
func (p *mistralModel) Replacements() []string {
	return []string{
		"tok_embeddings.weight", "token_embd",
		"norm", "output_norm",
		"layers", "blk",
		"attention_norm", "attn_norm",
		"attention.wq", "attn_q",
		"attention.wk", "attn_k",
		"attention.wv", "attn_v",
		"attention.wo", "attn_output",
		"feed_forward.w1", "ffn_gate",
		"feed_forward.w2", "ffn_down",
		"feed_forward.w3", "ffn_up",
	}
}
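These pairs feed a `strings.Replacer` that rewrites checkpoint tensor names into GGUF names. A quick sketch of the effect on one illustrative name (only a subset of the pairs is needed):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	r := strings.NewReplacer(
		"tok_embeddings.weight", "token_embd",
		"norm", "output_norm",
		"layers", "blk",
		"attention_norm", "attn_norm",
		"attention.wq", "attn_q",
	)
	// A Mistral-style checkpoint name mapped to its GGUF equivalent.
	fmt.Println(r.Replace("layers.0.attention.wq.weight")) // blk.0.attn_q.weight
}
```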
func (p *mistralModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
	var dims []int
	for _, dim := range shape {
		dims = append(dims, int(dim))
	}

	var heads uint32
	if strings.HasSuffix(name, "attn_q.weight") {
		heads = p.NumAttentionHeads
	} else if strings.HasSuffix(name, "attn_k.weight") {
		heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
	} else {
		return nil, fmt.Errorf("unknown tensor for repack: %s", name)
	}

	n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
	if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
		return nil, err
	}

	if err := n.T(0, 2, 1, 3); err != nil {
		return nil, err
	}

	if err := n.Reshape(dims...); err != nil {
		return nil, err
	}

	if err := n.Transpose(); err != nil {
		return nil, err
	}

	ts, err := native.SelectF32(n, 1)
	if err != nil {
		return nil, err
	}

	var f32s []float32
	for _, t := range ts {
		f32s = append(f32s, t...)
	}

	return f32s, nil
}
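`repack` reorders the rows of each Q/K head between the two rotary layouts: reshaping to `(heads, 2, headDim/2, ...)` and transposing axes 1 and 2 turns the half-split order (all first-half components, then all second-half components) into the interleaved pair order GGML's rope expects. The same index shuffle on plain slices, for one toy head of dimension 6:

```go
package main

import "fmt"

func main() {
	// Rows of one attention head, head_dim = 6; values equal their index.
	const headDim = 6
	const half = headDim / 2
	rows := []int{0, 1, 2, 3, 4, 5}

	// reshape(2, half) then transpose: row (j, k) moves to (k, j),
	// i.e. input index j*half+k maps to output index k*2+j.
	out := make([]int, headDim)
	for j := 0; j < 2; j++ {
		for k := 0; k < half; k++ {
			out[k*2+j] = rows[j*half+k]
		}
	}
	fmt.Println(out) // [0 3 1 4 2 5]
}
```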
@@ -62,10 +62,7 @@ func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
 		Pattern string
 		Func    func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
 	}{
-		{"model-*-of-*.safetensors", parseSafetensors},
-		{"model.safetensors", parseSafetensors},
-		{"adapters.safetensors", parseSafetensors},
-		{"adapter_model.safetensors", parseSafetensors},
+		{"*.safetensors", parseSafetensors},
 		{"pytorch_model-*-of-*.bin", parseTorch},
 		{"pytorch_model.bin", parseTorch},
 		{"consolidated.*.pth", parseTorch},
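Collapsing the four explicit safetensors patterns into a single `*.safetensors` relies on ordinary glob matching; a quick check that the one pattern covers everything the old list did (file names are illustrative):

```go
package main

import (
	"fmt"
	"path"
)

func main() {
	names := []string{
		"model-00001-of-00002.safetensors",
		"model.safetensors",
		"adapters.safetensors",
		"adapter_model.safetensors",
		"consolidated.00.pth", // still handled by the parseTorch patterns
	}
	for _, n := range names {
		ok, _ := path.Match("*.safetensors", n)
		fmt.Printf("%-36s %v\n", n, ok)
	}
}
```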
model/models/mistral/model.go · new file · 190 lines
@@ -0,0 +1,190 @@
package mistral

import (
	"fmt"
	"math"
	"strings"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Options struct {
	hiddenSize, numHeads, numKVHeads, headDim int
	eps, ropeBase, ropeScale                  float32
	ropeDim                                   uint32
}

type Model struct {
	model.Base
	model.BytePairEncoding

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*Options
}

func New(c ml.Config) (model.Model, error) {
	if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
		return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
	}

	m := Model{
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
			},
		),
		Layers: make([]Layer, c.Uint("block_count")),
		Options: &Options{
			hiddenSize: int(c.Uint("embedding_length")),
			numHeads:   int(c.Uint("attention.head_count")),
			numKVHeads: int(c.Uint("attention.head_count_kv")),
			headDim:    int(c.Uint("attention.key_length")),
			eps:        c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:   c.Float("rope.freq_base"),
			ropeScale:  c.Float("rope.freq_scale", 1),
			ropeDim:    c.Uint("rope.dimension_count"),
		},
	}

	m.Cache = kvcache.NewCausalCache(m.Shift)

	return &m, nil
}
type SelfAttention struct {
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	ropeType := uint32(0)

	// Get head dimension - use explicit value if available, otherwise calculate
	headDim := opts.headDim
	if headDim == 0 {
		headDim = opts.hiddenSize / opts.numHeads
	}

	// Query projection and reshape
	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

	// Key projection and reshape
	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

	// Value projection and reshape
	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	// Attention computation
	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)

	// Reshape attention output for final projection
	outputDim := headDim * opts.numHeads
	kqv = kqv.Reshape(ctx, outputDim, batchSize)

	// Apply output projection
	return sa.Output.Forward(ctx, kqv)
}
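Note the grouped-query shapes: Q is reshaped to `numHeads` heads while K and V use `numKVHeads`, and `headDim` falls back to `hiddenSize/numHeads` when the checkpoint carries no explicit `attention.key_length`. A toy calculation with hypothetical Mistral-7B-like dimensions:

```go
package main

import "fmt"

func main() {
	// Hypothetical dimensions for illustration only.
	hiddenSize, numHeads, numKVHeads := 4096, 32, 8

	headDim := 0 // no explicit attention.key_length in this checkpoint
	if headDim == 0 {
		headDim = hiddenSize / numHeads // same fallback as Forward above
	}

	fmt.Printf("q reshape:   %d x %d x batch\n", headDim, numHeads)   // 128 x 32
	fmt.Printf("k,v reshape: %d x %d x batch\n", headDim, numKVHeads) // 128 x 8
}
```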
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}
type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}
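This is the SwiGLU feed-forward: `down(SiLU(gate(x)) * up(x))`, where SiLU is `x * sigmoid(x)`. A scalar sanity check of the activation:

```go
package main

import (
	"fmt"
	"math"
)

// silu computes x * sigmoid(x), the activation applied to the gate branch.
func silu(x float64) float64 { return x / (1 + math.Exp(-x)) }

func main() {
	fmt.Println(silu(1.0)) // 0.7310585786300049
}
```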
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}
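The `outputs` pruning means only the positions that actually need logits (typically just the last token during generation) flow through the final residual add and the LM head. Conceptually, `Rows` gathers a subset of rows; in plain slices with toy data:

```go
package main

import "fmt"

func main() {
	// Hidden-state rows for a 5-token batch (toy 2-dim embeddings).
	hidden := [][]float32{{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}}
	outputs := []int{4} // generation usually needs only the last position

	pruned := make([][]float32, 0, len(outputs))
	for _, i := range outputs {
		pruned = append(pruned, hidden[i])
	}
	fmt.Println(pruned) // [[4 4]]: only this row reaches the LM head
}
```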
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
	if err != nil {
		return nil, err
	}

	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
	if err != nil {
		return nil, err
	}

	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
	if err != nil {
		return nil, err
	}

	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)

	for i, layer := range m.Layers {
		m.Cache.SetLayer(i)

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState), nil
}

func init() {
	model.Register("mistral", New)
}
@@ -211,16 +211,10 @@ func filesForModel(path string) ([]string, error) {
 	}
 
 	var files []string
-	if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
+	if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
 		// safetensors files might be unresolved git lfs references; skip if they are
 		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
 		files = append(files, st...)
-	} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
-		// covers adapters.safetensors
-		files = append(files, st...)
-	} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
-		// covers adapter_model.safetensors
-		files = append(files, st...)
 	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
 		// pytorch files might also be unresolved git lfs references; skip if they are
 		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin