generic ggml.array

This commit is contained in:
Michael Yang 2025-04-23 11:22:06 -07:00 committed by Michael Yang
parent 214a7678ea
commit 5d0279164c
2 changed files with 104 additions and 127 deletions

View File

@@ -105,32 +105,15 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
} }
func (kv KV) Strings(key string, defaultValue ...[]string) []string { func (kv KV) Strings(key string, defaultValue ...[]string) []string {
r := keyValue(kv, key, &array{}) return keyValue(kv, key, &array[string]{}).values
s := make([]string, r.size)
for i := range r.size {
s[i] = r.values[i].(string)
}
return s
} }
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 { func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
r := keyValue(kv, key, &array{}) return keyValue(kv, key, &array[uint32]{}).values
s := make([]uint32, r.size)
for i := range r.size {
s[i] = uint32(r.values[i].(int32))
}
return s
} }
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
r := keyValue(kv, key, &array{}) return keyValue(kv, key, &array[float32]{}).values
s := make([]float32, r.size)
for i := range r.size {
s[i] = float32(r.values[i].(float32))
}
return s
} }
func (kv KV) OllamaEngineRequired() bool { func (kv KV) OllamaEngineRequired() bool {
@@ -140,7 +123,12 @@ func (kv KV) OllamaEngineRequired() bool {
}, kv.Architecture()) }, kv.Architecture())
} }
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T { type valueTypes interface {
string | uint32 | uint64 | float32 | bool |
*array[string] | *array[uint32] | *array[uint64] | *array[float32] | *array[bool]
}
func keyValue[T valueTypes](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key key = kv.Architecture() + "." + key
} }
@@ -420,7 +408,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding := f.KV().EmbeddingLength() embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCount() heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV() headsKV := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size) vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
embeddingHeads := f.KV().EmbeddingHeadCount() embeddingHeads := f.KV().EmbeddingHeadCount()
embeddingHeadsK := f.KV().EmbeddingHeadCountK() embeddingHeadsK := f.KV().EmbeddingHeadCountK()

View File

@@ -36,10 +36,6 @@ type containerGGUF struct {
maxArraySize int maxArraySize int
} }
// canCollectArray reports whether size is within the container's array limit.
// A negative maxArraySize means no limit; otherwise size must not exceed
// maxArraySize.
func (c *containerGGUF) canCollectArray(size int) bool {
	if c.maxArraySize < 0 {
		return true
	}
	return size <= c.maxArraySize
}
func (c *containerGGUF) Name() string { func (c *containerGGUF) Name() string {
return "gguf" return "gguf"
} }
@@ -295,6 +291,23 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
return b.String(), nil return b.String(), nil
} }
// readGGUFV1StringsData consumes a.size GGUF v1 strings from r.
//
// When a.values is non-nil, each string is decoded with readGGUFV1String and
// stored at its index. When a.values is nil, each string is skipped with
// discardGGUFString so the reader still advances past the array.
//
// It returns a on success. Any read or discard error stops the loop and is
// returned together with a nil value.
func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
	for i := range a.size {
		if a.values == nil {
			// Not collecting values: skip the string, but report a failed
			// skip instead of continuing on a reader in an unknown state.
			if err := discardGGUFString(llm, r); err != nil {
				return nil, err
			}
			continue
		}

		s, err := readGGUFV1String(llm, r)
		if err != nil {
			return nil, err
		}
		a.values[i] = s
	}

	return a, nil
}
func discardGGUFString(llm *gguf, r io.Reader) error { func discardGGUFString(llm *gguf, r io.Reader) error {
buf := llm.scratch[:8] buf := llm.scratch[:8]
_, err := io.ReadFull(r, buf) _, err := io.ReadFull(r, buf)
@@ -352,78 +365,44 @@ func writeGGUFString(w io.Writer, s string) error {
return err return err
} }
type array struct { func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
size int for i := range a.size {
values []any
}
// MarshalJSON implements json.Marshaler by encoding only the values slice;
// the size field is not part of the output.
func (a *array) MarshalJSON() ([]byte, error) {
	encoded, err := json.Marshal(a.values)
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
n, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, 0, int(n))
}
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
e, err = readGGUFV1String(llm, r)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
if err != nil {
return nil, err
}
if a.values != nil { if a.values != nil {
e, err := readGGUFString(llm, r)
if err != nil {
return nil, err
}
a.values[i] = e a.values[i] = e
} else {
discardGGUFString(llm, r)
} }
} }
return a, nil return a, nil
} }
func readGGUFArray(llm *gguf, r io.Reader) (*array, error) { type array[T any] struct {
if llm.Version == 1 { // size is the actual size of the array
return readGGUFV1Array(llm, r) size int
}
// values is the array of values. this is nil if the array is larger than configured maxSize
values []T
}
// MarshalJSON implements json.Marshaler by encoding only the values slice;
// the size field is not part of the output.
func (a *array[T]) MarshalJSON() ([]byte, error) {
	encoded, err := json.Marshal(a.values)
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
// newArray returns an array[T] recording size. The values slice is allocated
// with length size when maxSize is negative or size does not exceed maxSize;
// otherwise values is left nil.
func newArray[T any](size, maxSize int) *array[T] {
	arr := &array[T]{size: size}
	if maxSize >= 0 && size > maxSize {
		// Over the configured limit: keep only the size.
		return arr
	}

	arr.values = make([]T, size)
	return arr
}
func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
t, err := readGGUF[uint32](llm, r) t, err := readGGUF[uint32](llm, r)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -434,45 +413,55 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
return nil, err return nil, err
} }
a := &array{size: int(n)} switch t {
if llm.canCollectArray(int(n)) { case ggufTypeUint8:
a.values = make([]any, int(n)) a := newArray[uint8](int(n), llm.maxArraySize)
} return readGGUFArrayData(llm, r, a)
case ggufTypeInt8:
for i := range n { a := newArray[int8](int(n), llm.maxArraySize)
var e any return readGGUFArrayData(llm, r, a)
switch t { case ggufTypeUint16:
case ggufTypeUint8: a := newArray[uint16](int(n), llm.maxArraySize)
e, err = readGGUF[uint8](llm, r) return readGGUFArrayData(llm, r, a)
case ggufTypeInt8: case ggufTypeInt16:
e, err = readGGUF[int8](llm, r) a := newArray[int16](int(n), llm.maxArraySize)
case ggufTypeUint16: return readGGUFArrayData(llm, r, a)
e, err = readGGUF[uint16](llm, r) case ggufTypeUint32:
case ggufTypeInt16: a := newArray[uint32](int(n), llm.maxArraySize)
e, err = readGGUF[int16](llm, r) return readGGUFArrayData(llm, r, a)
case ggufTypeUint32: case ggufTypeInt32:
e, err = readGGUF[uint32](llm, r) a := newArray[int32](int(n), llm.maxArraySize)
case ggufTypeInt32: return readGGUFArrayData(llm, r, a)
e, err = readGGUF[int32](llm, r) case ggufTypeUint64:
case ggufTypeUint64: a := newArray[uint64](int(n), llm.maxArraySize)
e, err = readGGUF[uint64](llm, r) return readGGUFArrayData(llm, r, a)
case ggufTypeInt64: case ggufTypeInt64:
e, err = readGGUF[int64](llm, r) a := newArray[int64](int(n), llm.maxArraySize)
case ggufTypeFloat32: return readGGUFArrayData(llm, r, a)
e, err = readGGUF[float32](llm, r) case ggufTypeFloat32:
case ggufTypeFloat64: a := newArray[float32](int(n), llm.maxArraySize)
e, err = readGGUF[float64](llm, r) return readGGUFArrayData(llm, r, a)
case ggufTypeBool: case ggufTypeFloat64:
e, err = readGGUF[bool](llm, r) a := newArray[float64](int(n), llm.maxArraySize)
case ggufTypeString: return readGGUFArrayData(llm, r, a)
if a.values != nil { case ggufTypeBool:
e, err = readGGUFString(llm, r) a := newArray[bool](int(n), llm.maxArraySize)
} else { return readGGUFArrayData(llm, r, a)
err = discardGGUFString(llm, r) case ggufTypeString:
} a := newArray[string](int(n), llm.maxArraySize)
default: if llm.Version == 1 {
return nil, fmt.Errorf("invalid array type: %d", t) return readGGUFV1StringsData(llm, r, a)
} }
return readGGUFStringsData(llm, r, a)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
}
func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) {
for i := range a.size {
e, err := readGGUF[T](llm, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }