model: Load tensors behind an interface
Currently, if a model uses an interface for its data structures (as mllama does), the tensor data in the structs implementing that interface is not loaded.
This commit is contained in:
parent
d223f3b697
commit
d650ad398f
@ -147,15 +147,12 @@ func New(s string) (Model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
v := reflect.ValueOf(m)
|
v := reflect.ValueOf(m)
|
||||||
v.Elem().Set(populateFields(b, v))
|
v.Elem().Set(populateFields(b, v.Elem()))
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
||||||
t := v.Type()
|
t := v.Type()
|
||||||
if t.Kind() == reflect.Pointer {
|
|
||||||
t, v = t.Elem(), v.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.Kind() == reflect.Struct {
|
if t.Kind() == reflect.Struct {
|
||||||
allNil := true
|
allNil := true
|
||||||
@ -205,18 +202,16 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if tt.Kind() == reflect.Pointer {
|
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
|
||||||
vvv := vv.Elem()
|
setPointer(b, vv, tagsCopy)
|
||||||
if vv.IsNil() {
|
|
||||||
vvv = reflect.New(tt.Elem())
|
|
||||||
}
|
|
||||||
|
|
||||||
if f := populateFields(b, vvv, tagsCopy...); f.CanAddr() {
|
|
||||||
vv.Set(f.Addr())
|
|
||||||
}
|
|
||||||
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
|
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
|
||||||
for i := range vv.Len() {
|
for i := range vv.Len() {
|
||||||
vv.Index(i).Set(populateFields(b, vv.Index(i), append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
|
vvv := vv.Index(i)
|
||||||
|
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
||||||
|
setPointer(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
|
||||||
|
} else {
|
||||||
|
vvv.Set(populateFields(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,6 +228,26 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
|||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setPointer populates the value that v refers to and stores the result back into v.
//
// v is handled as either an interface or a pointer value:
//   - a nil interface is left untouched, since there is no concrete type to allocate;
//   - a non-nil interface is unwrapped to the value it holds before dereferencing;
//   - a nil pointer gets a freshly allocated zero value of its element type.
//
// The dereferenced value is passed to populateFields together with tags, and v is
// only overwritten when the returned value is addressable.
//
// NOTE(review): assumes a non-nil interface wraps a non-nil pointer; the second
// Elem call would not yield a usable value otherwise — confirm against callers.
func setPointer(b ml.Backend, v reflect.Value, tags []Tag) {
|
||||||
|
vv := v
|
||||||
|
if v.Kind() == reflect.Interface {
|
||||||
|
if v.IsNil() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|

||||||
|
vv = vv.Elem() // unwrap the interface to the value it holds
|
||||||
|
}
|
||||||
|

||||||
|
vv = vv.Elem() // dereference to the pointed-to value
|
||||||
|
if v.IsNil() {
|
||||||
|
vv = reflect.New(v.Type().Elem()).Elem() // nil pointer: allocate a new element to populate
|
||||||
|
}
|
||||||
|

||||||
|
if f := populateFields(b, vv, tags...); f.CanAddr() {
|
||||||
|
v.Set(f.Addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Tag struct {
|
type Tag struct {
|
||||||
Name string
|
Name string
|
||||||
Alternate []string
|
Alternate []string
|
||||||
|
@ -90,7 +90,7 @@ func TestPopulateFields(t *testing.T) {
|
|||||||
"output_norm.weight",
|
"output_norm.weight",
|
||||||
"output.weight",
|
"output.weight",
|
||||||
},
|
},
|
||||||
}, v))
|
}, v.Elem()))
|
||||||
|
|
||||||
if diff := cmp.Diff(fakeModel{
|
if diff := cmp.Diff(fakeModel{
|
||||||
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
|
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
|
||||||
@ -125,7 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
|||||||
names: []string{
|
names: []string{
|
||||||
"input.weight",
|
"input.weight",
|
||||||
},
|
},
|
||||||
}, v))
|
}, v.Elem()))
|
||||||
|
|
||||||
if diff := cmp.Diff(fakeModel{
|
if diff := cmp.Diff(fakeModel{
|
||||||
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
|
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user