revert GroupLayers
This commit is contained in:
parent
3241b45790
commit
a3e0df1a5d
@ -157,15 +157,13 @@ func (ts Tensors) GroupLayers() map[string]Layer {
|
||||
layers := make(map[string]Layer)
|
||||
for _, t := range ts.items {
|
||||
parts := strings.Split(t.Name, ".")
|
||||
if i := slices.Index(parts, "blk"); i > 0 {
|
||||
parts = append([]string{
|
||||
strings.Join(parts[:i], "."),
|
||||
strings.Join(parts[i:i+2], "."),
|
||||
}, parts[i+2:]...)
|
||||
} else if i == 0 {
|
||||
parts = append([]string{
|
||||
strings.Join(parts[i:i+2], "."),
|
||||
}, parts[i+2:]...)
|
||||
if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
|
||||
if len(parts) > index+2 {
|
||||
// blk and mm should have a number after them, join it
|
||||
parts = append(
|
||||
[]string{strings.Join(parts[:index+2], ".")},
|
||||
parts[index+2:]...)
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := layers[parts[0]]; !ok {
|
||||
|
@ -85,23 +85,25 @@ func TestTensorLayers(t *testing.T) {
|
||||
}
|
||||
}),
|
||||
want: map[string]Layer{
|
||||
"mm": {
|
||||
"0.bias": tensors["mm.0.bias"],
|
||||
"0.weight": tensors["mm.0.weight"],
|
||||
"mm.0": {
|
||||
"bias": tensors["mm.0.bias"],
|
||||
"weight": tensors["mm.0.weight"],
|
||||
},
|
||||
"v.blk.0": {
|
||||
"attn_k.weight": tensors["v.blk.0.attn_k.weight"],
|
||||
"attn_q.weight": tensors["v.blk.0.attn_q.weight"],
|
||||
"attn_v.weight": tensors["v.blk.0.attn_v.weight"],
|
||||
"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
|
||||
"attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
|
||||
"ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
|
||||
"ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
|
||||
"ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
|
||||
"ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
|
||||
},
|
||||
"v": {
|
||||
"blk.0.attn_k.weight": tensors["v.blk.0.attn_k.weight"],
|
||||
"blk.0.attn_q.weight": tensors["v.blk.0.attn_q.weight"],
|
||||
"blk.0.attn_v.weight": tensors["v.blk.0.attn_v.weight"],
|
||||
"blk.0.attn_output.weight": tensors["v.blk.0.attn_output.weight"],
|
||||
"blk.0.attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
|
||||
"blk.0.ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
|
||||
"blk.0.ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
|
||||
"blk.0.ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
|
||||
"blk.0.ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
|
||||
"patch_embd.weight": tensors["v.patch_embd.weight"],
|
||||
"position_embd.gate": tensors["v.position_embd.gate"],
|
||||
"position_embd.weight": tensors["v.position_embd.weight"],
|
||||
"patch_embd.weight": tensors["v.patch_embd.weight"],
|
||||
"position_embd.gate": tensors["v.position_embd.gate"],
|
||||
"position_embd.weight": tensors["v.position_embd.weight"],
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -122,23 +124,25 @@ func TestTensorLayers(t *testing.T) {
|
||||
},
|
||||
"token_embd": {"weight": tensors["token_embd.weight"]},
|
||||
"output_norm": {"weight": tensors["output_norm.weight"]},
|
||||
"mm": {
|
||||
"0.bias": tensors["mm.0.bias"],
|
||||
"0.weight": tensors["mm.0.weight"],
|
||||
"mm.0": {
|
||||
"bias": tensors["mm.0.bias"],
|
||||
"weight": tensors["mm.0.weight"],
|
||||
},
|
||||
"v.blk.0": {
|
||||
"attn_k.weight": tensors["v.blk.0.attn_k.weight"],
|
||||
"attn_q.weight": tensors["v.blk.0.attn_q.weight"],
|
||||
"attn_v.weight": tensors["v.blk.0.attn_v.weight"],
|
||||
"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
|
||||
"attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
|
||||
"ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
|
||||
"ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
|
||||
"ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
|
||||
"ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
|
||||
},
|
||||
"v": {
|
||||
"blk.0.attn_k.weight": tensors["v.blk.0.attn_k.weight"],
|
||||
"blk.0.attn_q.weight": tensors["v.blk.0.attn_q.weight"],
|
||||
"blk.0.attn_v.weight": tensors["v.blk.0.attn_v.weight"],
|
||||
"blk.0.attn_output.weight": tensors["v.blk.0.attn_output.weight"],
|
||||
"blk.0.attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
|
||||
"blk.0.ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
|
||||
"blk.0.ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
|
||||
"blk.0.ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
|
||||
"blk.0.ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
|
||||
"patch_embd.weight": tensors["v.patch_embd.weight"],
|
||||
"position_embd.gate": tensors["v.position_embd.gate"],
|
||||
"position_embd.weight": tensors["v.position_embd.weight"],
|
||||
"patch_embd.weight": tensors["v.patch_embd.weight"],
|
||||
"position_embd.gate": tensors["v.position_embd.gate"],
|
||||
"position_embd.weight": tensors["v.position_embd.weight"],
|
||||
},
|
||||
},
|
||||
},
|
||||
|
Loading…
x
Reference in New Issue
Block a user