From 5d0279164c2fcb4a1d3100d30988ba54ace548d1 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 23 Apr 2025 11:22:06 -0700 Subject: [PATCH] generic ggml.array --- fs/ggml/ggml.go | 32 +++----- fs/ggml/gguf.go | 199 +++++++++++++++++++++++------------------------- 2 files changed, 104 insertions(+), 127 deletions(-) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 9431e9cc1..1ba0813f7 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -105,32 +105,15 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool { } func (kv KV) Strings(key string, defaultValue ...[]string) []string { - r := keyValue(kv, key, &array{}) - s := make([]string, r.size) - for i := range r.size { - s[i] = r.values[i].(string) - } - - return s + return keyValue(kv, key, &array[string]{}).values } func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 { - r := keyValue(kv, key, &array{}) - s := make([]uint32, r.size) - for i := range r.size { - s[i] = uint32(r.values[i].(int32)) - } - - return s + return keyValue(kv, key, &array[uint32]{}).values } func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { - r := keyValue(kv, key, &array{}) - s := make([]float32, r.size) - for i := range r.size { - s[i] = float32(r.values[i].(float32)) - } - return s + return keyValue(kv, key, &array[float32]{}).values } func (kv KV) OllamaEngineRequired() bool { @@ -140,7 +123,12 @@ func (kv KV) OllamaEngineRequired() bool { }, kv.Architecture()) } -func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T { +type valueTypes interface { + string | uint32 | uint64 | float32 | bool | + *array[string] | *array[uint32] | *array[uint64] | *array[float32] | *array[bool] +} + +func keyValue[T valueTypes](kv KV, key string, defaultValue ...T) T { if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { key = kv.Architecture() + "." + key } @@ -420,7 +408,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri embedding := f.KV().EmbeddingLength() heads := f.KV().HeadCount() headsKV := f.KV().HeadCountKV() - vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size) + vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size) embeddingHeads := f.KV().EmbeddingHeadCount() embeddingHeadsK := f.KV().EmbeddingHeadCountK() diff --git a/fs/ggml/gguf.go b/fs/ggml/gguf.go index 28e89c18a..fb3421576 100644 --- a/fs/ggml/gguf.go +++ b/fs/ggml/gguf.go @@ -36,10 +36,6 @@ type containerGGUF struct { maxArraySize int } -func (c *containerGGUF) canCollectArray(size int) bool { - return c.maxArraySize < 0 || size <= c.maxArraySize -} - func (c *containerGGUF) Name() string { return "gguf" } @@ -295,6 +291,23 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) { return b.String(), nil } +func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) { + for i := range a.size { + if a.values != nil { + e, err := readGGUFV1String(llm, r) + if err != nil { + return nil, err + } + + a.values[i] = e + } else { + discardGGUFString(llm, r) + } + } + + return a, nil +} + func discardGGUFString(llm *gguf, r io.Reader) error { buf := llm.scratch[:8] _, err := io.ReadFull(r, buf) @@ -352,78 +365,44 @@ func writeGGUFString(w io.Writer, s string) error { return err } -type array struct { - size int - values []any -} - -func (a *array) MarshalJSON() ([]byte, error) { - return json.Marshal(a.values) -} - -func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) { - t, err := readGGUF[uint32](llm, r) - if err != nil { - return nil, err - } - - n, err := readGGUF[uint32](llm, r) - if err != nil { - return nil, err - } - - a := &array{size: int(n)} - if llm.canCollectArray(int(n)) { - a.values = make([]any, 0, int(n)) - } - - for i := range n { - var e any - switch t { - case ggufTypeUint8: - e, err = readGGUF[uint8](llm, r) - case ggufTypeInt8: - e, err = readGGUF[int8](llm, r) - case ggufTypeUint16: - e, err = readGGUF[uint16](llm, r) - case ggufTypeInt16: - e, err = readGGUF[int16](llm, r) - case ggufTypeUint32: - e, err = readGGUF[uint32](llm, r) - case ggufTypeInt32: - e, err = readGGUF[int32](llm, r) - case ggufTypeUint64: - e, err = readGGUF[uint64](llm, r) - case ggufTypeInt64: - e, err = readGGUF[int64](llm, r) - case ggufTypeFloat32: - e, err = readGGUF[float32](llm, r) - case ggufTypeFloat64: - e, err = readGGUF[float64](llm, r) - case ggufTypeBool: - e, err = readGGUF[bool](llm, r) - case ggufTypeString: - e, err = readGGUFV1String(llm, r) - default: - return nil, fmt.Errorf("invalid array type: %d", t) - } - if err != nil { - return nil, err - } - +func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) { + for i := range a.size { if a.values != nil { + e, err := readGGUFString(llm, r) + if err != nil { + return nil, err + } + a.values[i] = e + } else { + discardGGUFString(llm, r) } } return a, nil } -func readGGUFArray(llm *gguf, r io.Reader) (*array, error) { - if llm.Version == 1 { - return readGGUFV1Array(llm, r) - } +type array[T any] struct { + // size is the actual size of the array + size int + // values is the array of values. this is nil if the array is larger than configured maxSize + values []T +} + +func (a *array[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(a.values) +} + +func newArray[T any](size, maxSize int) *array[T] { + a := array[T]{size: size} + if maxSize < 0 || size <= maxSize { + a.values = make([]T, size) + } + return &a +} + +func readGGUFArray(llm *gguf, r io.Reader) (any, error) { t, err := readGGUF[uint32](llm, r) if err != nil { return nil, err @@ -434,45 +413,55 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) { return nil, err } - a := &array{size: int(n)} - if llm.canCollectArray(int(n)) { - a.values = make([]any, int(n)) - } - - for i := range n { - var e any - switch t { - case ggufTypeUint8: - e, err = readGGUF[uint8](llm, r) - case ggufTypeInt8: - e, err = readGGUF[int8](llm, r) - case ggufTypeUint16: - e, err = readGGUF[uint16](llm, r) - case ggufTypeInt16: - e, err = readGGUF[int16](llm, r) - case ggufTypeUint32: - e, err = readGGUF[uint32](llm, r) - case ggufTypeInt32: - e, err = readGGUF[int32](llm, r) - case ggufTypeUint64: - e, err = readGGUF[uint64](llm, r) - case ggufTypeInt64: - e, err = readGGUF[int64](llm, r) - case ggufTypeFloat32: - e, err = readGGUF[float32](llm, r) - case ggufTypeFloat64: - e, err = readGGUF[float64](llm, r) - case ggufTypeBool: - e, err = readGGUF[bool](llm, r) - case ggufTypeString: - if a.values != nil { - e, err = readGGUFString(llm, r) - } else { - err = discardGGUFString(llm, r) - } - default: - return nil, fmt.Errorf("invalid array type: %d", t) + switch t { + case ggufTypeUint8: + a := newArray[uint8](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeInt8: + a := newArray[int8](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeUint16: + a := newArray[uint16](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeInt16: + a := newArray[int16](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeUint32: + a := newArray[uint32](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeInt32: + a := newArray[int32](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeUint64: + a := newArray[uint64](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeInt64: + a := newArray[int64](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeFloat32: + a := newArray[float32](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeFloat64: + a := newArray[float64](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeBool: + a := newArray[bool](int(n), llm.maxArraySize) + return readGGUFArrayData(llm, r, a) + case ggufTypeString: + a := newArray[string](int(n), llm.maxArraySize) + if llm.Version == 1 { + return readGGUFV1StringsData(llm, r, a) } + + return readGGUFStringsData(llm, r, a) + default: + return nil, fmt.Errorf("invalid array type: %d", t) + } +} + +func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) { + for i := range a.size { + e, err := readGGUF[T](llm, r) if err != nil { return nil, err }