revert GroupLayers

This commit is contained in:
Michael Yang 2025-02-13 15:35:21 -08:00
parent 3241b45790
commit a3e0df1a5d
2 changed files with 41 additions and 39 deletions

View File

@@ -157,15 +157,13 @@ func (ts Tensors) GroupLayers() map[string]Layer {
layers := make(map[string]Layer)
for _, t := range ts.items {
parts := strings.Split(t.Name, ".")
if i := slices.Index(parts, "blk"); i > 0 {
parts = append([]string{
strings.Join(parts[:i], "."),
strings.Join(parts[i:i+2], "."),
}, parts[i+2:]...)
} else if i == 0 {
parts = append([]string{
strings.Join(parts[i:i+2], "."),
}, parts[i+2:]...)
if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
if len(parts) > index+2 {
// blk and mm should have a number after them, join it
parts = append(
[]string{strings.Join(parts[:index+2], ".")},
parts[index+2:]...)
}
}
if _, ok := layers[parts[0]]; !ok {

View File

@@ -85,23 +85,25 @@ func TestTensorLayers(t *testing.T) {
}
}),
want: map[string]Layer{
"mm": {
"0.bias": tensors["mm.0.bias"],
"0.weight": tensors["mm.0.weight"],
"mm.0": {
"bias": tensors["mm.0.bias"],
"weight": tensors["mm.0.weight"],
},
"v.blk.0": {
"attn_k.weight": tensors["v.blk.0.attn_k.weight"],
"attn_q.weight": tensors["v.blk.0.attn_q.weight"],
"attn_v.weight": tensors["v.blk.0.attn_v.weight"],
"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
"attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
"ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
"ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
"ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
"ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
},
"v": {
"blk.0.attn_k.weight": tensors["v.blk.0.attn_k.weight"],
"blk.0.attn_q.weight": tensors["v.blk.0.attn_q.weight"],
"blk.0.attn_v.weight": tensors["v.blk.0.attn_v.weight"],
"blk.0.attn_output.weight": tensors["v.blk.0.attn_output.weight"],
"blk.0.attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
"blk.0.ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
"blk.0.ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
"blk.0.ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
"blk.0.ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
"patch_embd.weight": tensors["v.patch_embd.weight"],
"position_embd.gate": tensors["v.position_embd.gate"],
"position_embd.weight": tensors["v.position_embd.weight"],
"patch_embd.weight": tensors["v.patch_embd.weight"],
"position_embd.gate": tensors["v.position_embd.gate"],
"position_embd.weight": tensors["v.position_embd.weight"],
},
},
},
@@ -122,23 +124,25 @@ func TestTensorLayers(t *testing.T) {
},
"token_embd": {"weight": tensors["token_embd.weight"]},
"output_norm": {"weight": tensors["output_norm.weight"]},
"mm": {
"0.bias": tensors["mm.0.bias"],
"0.weight": tensors["mm.0.weight"],
"mm.0": {
"bias": tensors["mm.0.bias"],
"weight": tensors["mm.0.weight"],
},
"v.blk.0": {
"attn_k.weight": tensors["v.blk.0.attn_k.weight"],
"attn_q.weight": tensors["v.blk.0.attn_q.weight"],
"attn_v.weight": tensors["v.blk.0.attn_v.weight"],
"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
"attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
"ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
"ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
"ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
"ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
},
"v": {
"blk.0.attn_k.weight": tensors["v.blk.0.attn_k.weight"],
"blk.0.attn_q.weight": tensors["v.blk.0.attn_q.weight"],
"blk.0.attn_v.weight": tensors["v.blk.0.attn_v.weight"],
"blk.0.attn_output.weight": tensors["v.blk.0.attn_output.weight"],
"blk.0.attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
"blk.0.ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
"blk.0.ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
"blk.0.ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
"blk.0.ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
"patch_embd.weight": tensors["v.patch_embd.weight"],
"position_embd.gate": tensors["v.position_embd.gate"],
"position_embd.weight": tensors["v.position_embd.weight"],
"patch_embd.weight": tensors["v.patch_embd.weight"],
"position_embd.gate": tensors["v.position_embd.gate"],
"position_embd.weight": tensors["v.position_embd.weight"],
},
},
},