diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index dd8a01d95..90d1d4406 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -157,15 +157,13 @@ func (ts Tensors) GroupLayers() map[string]Layer {
 	layers := make(map[string]Layer)
 	for _, t := range ts.items {
 		parts := strings.Split(t.Name, ".")
-		if i := slices.Index(parts, "blk"); i > 0 {
-			parts = append([]string{
-				strings.Join(parts[:i], "."),
-				strings.Join(parts[i:i+2], "."),
-			}, parts[i+2:]...)
-		} else if i == 0 {
-			parts = append([]string{
-				strings.Join(parts[i:i+2], "."),
-			}, parts[i+2:]...)
+		if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
+			if len(parts) > index+2 {
+				// blk and mm should have a number after them, join it
+				parts = append(
+					[]string{strings.Join(parts[:index+2], ".")},
+					parts[index+2:]...)
+			}
 		}
 
 		if _, ok := layers[parts[0]]; !ok {
diff --git a/fs/ggml/ggml_test.go b/fs/ggml/ggml_test.go
index 93aa95adb..4fcdf0854 100644
--- a/fs/ggml/ggml_test.go
+++ b/fs/ggml/ggml_test.go
@@ -85,23 +85,25 @@ func TestTensorLayers(t *testing.T) {
 				}
 			}),
 			want: map[string]Layer{
-				"mm": {
-					"0.bias":   tensors["mm.0.bias"],
-					"0.weight": tensors["mm.0.weight"],
+				"mm.0": {
+					"bias":   tensors["mm.0.bias"],
+					"weight": tensors["mm.0.weight"],
+				},
+				"v.blk.0": {
+					"attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
+					"attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
+					"attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
+					"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
+					"attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
+					"ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
+					"ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
+					"ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
+					"ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
 				},
 				"v": {
-					"blk.0.attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
-					"blk.0.attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
-					"blk.0.attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
-					"blk.0.attn_output.weight": tensors["v.blk.0.attn_output.weight"],
-					"blk.0.attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
-					"blk.0.ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
-					"blk.0.ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
-					"blk.0.ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
-					"blk.0.ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
-					"patch_embd.weight":        tensors["v.patch_embd.weight"],
-					"position_embd.gate":       tensors["v.position_embd.gate"],
-					"position_embd.weight":     tensors["v.position_embd.weight"],
+					"patch_embd.weight":    tensors["v.patch_embd.weight"],
+					"position_embd.gate":   tensors["v.position_embd.gate"],
+					"position_embd.weight": tensors["v.position_embd.weight"],
 				},
 			},
 		},
@@ -122,23 +124,25 @@ func TestTensorLayers(t *testing.T) {
 				},
 				"token_embd":  {"weight": tensors["token_embd.weight"]},
 				"output_norm": {"weight": tensors["output_norm.weight"]},
-				"mm": {
-					"0.bias":   tensors["mm.0.bias"],
-					"0.weight": tensors["mm.0.weight"],
+				"mm.0": {
+					"bias":   tensors["mm.0.bias"],
+					"weight": tensors["mm.0.weight"],
+				},
+				"v.blk.0": {
+					"attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
+					"attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
+					"attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
+					"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
+					"attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
+					"ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
+					"ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
+					"ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
+					"ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
 				},
 				"v": {
-					"blk.0.attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
-					"blk.0.attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
-					"blk.0.attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
-					"blk.0.attn_output.weight": tensors["v.blk.0.attn_output.weight"],
-					"blk.0.attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
-					"blk.0.ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
-					"blk.0.ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
-					"blk.0.ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
-					"blk.0.ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
-					"patch_embd.weight":        tensors["v.patch_embd.weight"],
-					"position_embd.gate":       tensors["v.position_embd.gate"],
-					"position_embd.weight":     tensors["v.position_embd.weight"],
+					"patch_embd.weight":    tensors["v.patch_embd.weight"],
+					"position_embd.gate":   tensors["v.position_embd.gate"],
+					"position_embd.weight": tensors["v.position_embd.weight"],
 				},
 			},
 		},