revert GroupLayers

This commit is contained in:
Michael Yang 2025-02-13 15:35:21 -08:00
parent 3241b45790
commit a3e0df1a5d
2 changed files with 41 additions and 39 deletions

View File

@@ -157,15 +157,13 @@ func (ts Tensors) GroupLayers() map[string]Layer {
 	layers := make(map[string]Layer)
 	for _, t := range ts.items {
 		parts := strings.Split(t.Name, ".")
-		if i := slices.Index(parts, "blk"); i > 0 {
-			parts = append([]string{
-				strings.Join(parts[:i], "."),
-				strings.Join(parts[i:i+2], "."),
-			}, parts[i+2:]...)
-		} else if i == 0 {
-			parts = append([]string{
-				strings.Join(parts[i:i+2], "."),
-			}, parts[i+2:]...)
+		if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
+			if len(parts) > index+2 {
+				// blk and mm should have a number after them, join it
+				parts = append(
+					[]string{strings.Join(parts[:index+2], ".")},
+					parts[index+2:]...)
+			}
 		}
 
 		if _, ok := layers[parts[0]]; !ok {

View File

@@ -85,20 +85,22 @@ func TestTensorLayers(t *testing.T) {
 			}
 		}),
 		want: map[string]Layer{
-			"mm": {
-				"0.bias":   tensors["mm.0.bias"],
-				"0.weight": tensors["mm.0.weight"],
+			"mm.0": {
+				"bias":   tensors["mm.0.bias"],
+				"weight": tensors["mm.0.weight"],
+			},
+			"v.blk.0": {
+				"attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
+				"attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
+				"attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
+				"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
+				"attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
+				"ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
+				"ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
+				"ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
+				"ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
 			},
 			"v": {
-				"blk.0.attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
-				"blk.0.attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
-				"blk.0.attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
-				"blk.0.attn_output.weight": tensors["v.blk.0.attn_output.weight"],
-				"blk.0.attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
-				"blk.0.ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
-				"blk.0.ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
-				"blk.0.ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
-				"blk.0.ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
 				"patch_embd.weight":    tensors["v.patch_embd.weight"],
 				"position_embd.gate":   tensors["v.position_embd.gate"],
 				"position_embd.weight": tensors["v.position_embd.weight"],
@@ -122,20 +124,22 @@ func TestTensorLayers(t *testing.T) {
 			},
 			"token_embd":  {"weight": tensors["token_embd.weight"]},
 			"output_norm": {"weight": tensors["output_norm.weight"]},
-			"mm": {
-				"0.bias":   tensors["mm.0.bias"],
-				"0.weight": tensors["mm.0.weight"],
+			"mm.0": {
+				"bias":   tensors["mm.0.bias"],
+				"weight": tensors["mm.0.weight"],
+			},
+			"v.blk.0": {
+				"attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
+				"attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
+				"attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
+				"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
+				"attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
+				"ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
+				"ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
+				"ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
+				"ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
 			},
 			"v": {
-				"blk.0.attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
-				"blk.0.attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
-				"blk.0.attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
-				"blk.0.attn_output.weight": tensors["v.blk.0.attn_output.weight"],
-				"blk.0.attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
-				"blk.0.ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
-				"blk.0.ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
-				"blk.0.ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
-				"blk.0.ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
 				"patch_embd.weight":    tensors["v.patch_embd.weight"],
 				"position_embd.gate":   tensors["v.position_embd.gate"],
 				"position_embd.weight": tensors["v.position_embd.weight"],