draft: mlx

Michael Yang 2024-12-31 11:13:09 -08:00
parent 67bcb55941
commit 6f09d63862
11 changed files with 64370 additions and 72 deletions

View File

@@ -2,33 +2,9 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
-include(CheckLanguage)
-find_package(Threads REQUIRED)
set(CMAKE_BUILD_TYPE Release)
set(BUILD_SHARED_LIBS ON)
-set(CMAKE_CXX_STANDARD 11)
-set(CMAKE_CXX_STANDARD_REQUIRED ON)
-set(CMAKE_CXX_EXTENSIONS OFF)
-set(GGML_CCACHE ON)
-set(GGML_SCHED_MAX_COPIES 4)
-set(GGML_CPU_ALL_VARIANTS ON)
-set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
-set(GGML_LLAMAFILE ON)
-add_compile_definitions(GGML_BUILD)
-add_compile_definitions(GGML_SHARED)
-add_compile_definitions(GGML_BACKEND_DL)
-add_compile_definitions(GGML_BACKEND_SHARED)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
function(set_target_output_directory _target)
if(TARGET ${_target})
set_target_properties(${_target} PROPERTIES
@@ -39,30 +15,8 @@ function(set_target_output_directory _target)
endif()
endfunction()
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src EXCLUDE_FROM_ALL)
-set_target_output_directory(ggml-base)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml)
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
-set_target_output_directory(ggml-cpu)
-find_package(BLAS)
-if(NOT BLAS_VENDOR)
-  set(GGML_BLAS_VENDOR "Generic")
-else()
-  set(GGML_BLAS_VENDOR ${BLAS_VENDOR})
-endif()
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-blas)
-set_target_output_directory(ggml-blas)
-check_language(CUDA)
-if(CMAKE_CUDA_COMPILER)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
-  set_target_output_directory(ggml-cuda)
-endif()
-check_language(HIP)
-if(CMAKE_HIP_COMPILER)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
-  set_target_output_directory(ggml-hip)
-endif()
+if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/mlx)
+endif()
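Net effect in this file: the GGML-specific setup (language checks, compiler settings, BLAS/CUDA/HIP probing) moves down into ml/backend/ggml, which the top level now pulls in with a single add_subdirectory, and the new MLX backend is configured only when building on macOS.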

View File

@@ -35,7 +35,7 @@ func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
}
func NewBackend(f *os.File) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
if backend, ok := backends["mlx"]; ok {
return backend(f)
}
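The registry is keyed by name, and this draft simply swaps which key NewBackend asks for. A sketch of how the same registry could make the choice configurable; the OLLAMA_BACKEND variable is hypothetical, not part of this commit:

func NewBackend(f *os.File) (Backend, error) {
	name := os.Getenv("OLLAMA_BACKEND") // hypothetical knob, not in this commit
	if name == "" {
		name = "mlx"
	}
	if backend, ok := backends[name]; ok {
		return backend(f)
	}
	return nil, fmt.Errorf("unsupported backend %q", name)
}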
@@ -67,8 +67,8 @@ type Tensor interface {
Mulmat(ctx Context, t2 Tensor) Tensor
Softmax(ctx Context) Tensor
-	Norm(ctx Context, eps float32) Tensor
-	RMSNorm(ctx Context, eps float32) Tensor
+	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
+	RMSNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
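The signature change folds the elementwise weight multiply and the optional bias add into the norm call itself, letting each backend lower the whole operation to a fused kernel. The call shape after the change, with illustrative tensor names (bias may be nil):

	// weight and bias come from the loaded model; the names are illustrative.
	h = h.RMSNorm(ctx, weight, nil, 1e-6)    // scale only
	h = h.LayerNorm(ctx, weight, bias, 1e-5) // scale and shift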

View File

@@ -2,4 +2,5 @@ package backend
import (
_ "github.com/ollama/ollama/ml/backend/ggml"
_ "github.com/ollama/ollama/ml/backend/mlx"
)
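The blank import is load-bearing: importing the package for its side effects runs its init function, which is what calls ml.RegisterBackend (see mlx.go below).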

View File

@@ -0,0 +1,51 @@
include(CheckLanguage)
find_package(Threads REQUIRED)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(GGML_CCACHE ON)
set(GGML_SCHED_MAX_COPIES 4)
set(GGML_CPU_ALL_VARIANTS ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_LLAMAFILE ON)
add_compile_definitions(GGML_BUILD)
add_compile_definitions(GGML_SHARED)
add_compile_definitions(GGML_BACKEND_DL)
add_compile_definitions(GGML_BACKEND_SHARED)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-cpu/amx)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src EXCLUDE_FROM_ALL)
set_target_output_directory(ggml-base)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-cpu)
set_target_output_directory(ggml-cpu)
find_package(BLAS)
if(NOT BLAS_VENDOR)
set(GGML_BLAS_VENDOR "Generic")
else()
set(GGML_BLAS_VENDOR ${BLAS_VENDOR})
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-blas)
set_target_output_directory(ggml-blas)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-cuda)
set_target_output_directory(ggml-cuda)
endif()
check_language(HIP)
if(CMAKE_HIP_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-hip)
set_target_output_directory(ggml-hip)
endif()

View File

@@ -423,16 +423,28 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
-func (t *Tensor) Norm(ctx ml.Context, eps float32) ml.Tensor {
-	return &Tensor{
-		t: C.ggml_norm(ctx.(*Context).ctx, t.t, (C.float)(eps)),
-	}
-}
+func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
+	r := (&Tensor{
+		t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
+	}).Mul(ctx, w)
+	if b != nil {
+		r = r.Add(ctx, b)
+	}
+	return r
+}
-func (t *Tensor) RMSNorm(ctx ml.Context, eps float32) ml.Tensor {
-	return &Tensor{
-		t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
-	}
-}
+func (t *Tensor) RMSNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
+	r := (&Tensor{
+		t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
+	}).Mul(ctx, w)
+	if b != nil {
+		r = r.Add(ctx, b)
+	}
+	return r
+}
func (t *Tensor) Pad(ctx ml.Context, shape ...int64) ml.Tensor {

View File

@@ -0,0 +1,26 @@
include(FetchContent)
set(MLX_C_BUILD_EXAMPLES OFF)
set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS OFF)
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(NOT MLX_METAL_VERSION)
message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG v0.1.0)
FetchContent_MakeAvailable(mlx-c)
set_target_output_directory(mlx)
set_target_output_directory(mlxc)
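The execute_process probe compiles nothing: it pipes the literal __METAL_VERSION__ through xcrun metal -E so the preprocessor expands it to the toolchain's Metal language version. An empty result means there is no usable Metal toolchain, so MLX_BUILD_METAL is switched off and MLX falls back to a CPU-only build.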

ml/backend/mlx/mlx.go Normal file (364 lines)
View File

@@ -0,0 +1,364 @@
package mlx
// #cgo CPPFLAGS: -I${SRCDIR}/../../../build/_deps/mlx-c-src
// #cgo LDFLAGS: -L${SRCDIR}/../../../build/lib -lmlxc -lmlx
// #cgo LDFLAGS: -framework Accelerate
// #cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../../build/lib
// #include <stdlib.h>
// #include "mlx/c/array.h"
// #include "mlx/c/fast.h"
// #include "mlx/c/ops.h"
// #include "mlx/c/stream.h"
import "C"
import (
"bytes"
"fmt"
"io"
"log/slog"
"os"
"sync"
"unsafe"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"golang.org/x/sync/errgroup"
)
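The cgo directives point at the FetchContent checkout for headers and at build/lib for libmlxc and libmlx; the -rpath entry bakes that directory into the binary so it can run straight out of the build tree without an install step.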
func init() {
ml.RegisterBackend("mlx", New)
}
func New(r *os.File) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1)
if err != nil {
return nil, err
}
tensors := make(map[string]*Array, len(meta.Tensors().Items()))
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
stream := C.mlx_default_cpu_stream_new()
var g errgroup.Group
var mu sync.Mutex
for _, t := range meta.Tensors().Items() {
g.Go(func() error {
var b bytes.Buffer
n, err := io.Copy(&b, io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())))
if err != nil {
return err
}
if n != int64(t.Size()) {
return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
}
cbytes := C.CBytes(b.Bytes())
defer C.free(cbytes)
shape := make([]C.int, len(t.Shape))
for i, dim := range t.Shape {
shape[i] = C.int(dim)
}
var dtype C.mlx_dtype
switch t.Kind {
case 0:
dtype = C.MLX_FLOAT32
case 1:
dtype = C.MLX_FLOAT16
default:
return fmt.Errorf("unsupported dtype %d", t.Kind)
}
mu.Lock()
defer mu.Unlock()
var a C.mlx_array
C.mlx_transpose_all(
&a,
C.mlx_array_new_data(
cbytes,
(*C.int)(&shape[0]),
C.int(len(shape)),
dtype,
),
stream,
)
tensors[t.Name] = &Array{
name: t.Name,
a: a,
}
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return &Backend{
meta: meta,
tensors: tensors,
}, nil
}
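Tensors load concurrently, one errgroup goroutine each, with mlx_array construction serialized under the mutex; mlx_transpose_all then reverses each array's axes, which reads as flipping GGUF's innermost-first dimension order to MLX's convention. A minimal usage sketch, assuming a GGUF file on disk (path and tensor name are illustrative):

package main

import (
	"log"
	"os"

	"github.com/ollama/ollama/ml/backend/mlx"
)

func main() {
	f, err := os.Open("model.gguf") // illustrative path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	b, err := mlx.New(f) // decodes GGUF metadata, then loads every tensor
	if err != nil {
		log.Fatal(err)
	}
	_ = b.Get("token_embd.weight") // returns nil if the tensor is absent
}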
type Backend struct {
meta *fs.GGML
tensors map[string]*Array
}
// Config implements ml.Backend.
func (b *Backend) Config() ml.Config {
return b.meta.KV()
}
// Get implements ml.Backend.
func (b *Backend) Get(name string) ml.Tensor {
if a, ok := b.tensors[name]; ok {
return a
}
return nil
}
func (b *Backend) NewContext() ml.Context {
return &Context{
stream: C.mlx_default_cpu_stream_new(),
}
}
type Context struct {
stream C.mlx_stream
}
// Close implements ml.Context.
func (c *Context) Close() error {
panic("unimplemented")
}
// Compute implements ml.Context.
func (c *Context) Compute(ml.Tensor) ml.Tensor {
panic("unimplemented")
}
// Forward implements ml.Context.
func (c *Context) Forward(ml.Tensor) {
panic("unimplemented")
}
// FromFloatSlice implements ml.Context.
func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
panic("unimplemented")
}
// FromIntSlice implements ml.Context.
func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
cshape := make([]C.int, len(shape))
for i, dim := range shape {
cshape[i] = C.int(dim)
}
return &Array{
a: C.mlx_array_new_data(
unsafe.Pointer(&s[0]),
(*C.int)(&cshape[0]),
C.int(len(cshape)),
C.MLX_INT32,
),
}, nil
}
// Zeros implements ml.Context.
func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
panic("unimplemented")
}
type Array struct {
name string
a C.mlx_array
}
func (a *Array) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", a.name),
slog.Any("shape", a.Shape()),
)
}
// Add implements ml.Tensor.
func (a *Array) Add(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
panic("unimplemented")
}
// Bytes implements ml.Tensor.
func (a *Array) Bytes() []byte {
panic("unimplemented")
}
// Concat implements ml.Tensor.
func (a *Array) Concat(ctx ml.Context, a2 ml.Tensor, dim int) ml.Tensor {
panic("unimplemented")
}
// Contiguous implements ml.Tensor.
func (a *Array) Contiguous(ctx ml.Context) ml.Tensor {
panic("unimplemented")
}
// Conv2D implements ml.Tensor.
func (a *Array) Conv2D(ctx ml.Context, weight ml.Tensor, s0 int, s1 int, p0 int, p1 int, d0 int, d1 int) ml.Tensor {
panic("unimplemented")
}
// Copy implements ml.Tensor.
func (a *Array) Copy(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
panic("unimplemented")
}
// DType implements ml.Tensor.
func (a *Array) DType() ml.DType {
panic("unimplemented")
}
// Dim implements ml.Tensor.
func (a *Array) Dim(n int) int64 {
return int64(C.mlx_array_dim(a.a, C.int(n)))
}
// Floats implements ml.Tensor.
func (a *Array) Floats() []float32 {
panic("unimplemented")
}
// GELU implements ml.Tensor.
func (a *Array) GELU(ctx ml.Context) ml.Tensor {
panic("unimplemented")
}
// Mul implements ml.Tensor.
func (a *Array) Mul(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
panic("unimplemented")
}
// Mulmat implements ml.Tensor.
func (a *Array) Mulmat(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
slog.Info("mulmat", "a", a, "a2", a2)
	// Reverse the receiver's axes (a plain transpose for 2-D weights,
	// matching ggml's mul_mat convention) before multiplying.
	var at C.mlx_array
	C.mlx_transpose_all(&at, a.a, ctx.(*Context).stream)
	var r C.mlx_array
	C.mlx_matmul(&r, a2.(*Array).a, at, ctx.(*Context).stream)
	return &Array{a: r}
}
// LayerNorm implements ml.Tensor.
func (a *Array) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
	// The interface allows a nil bias; guard the type assertion so a
	// missing bias doesn't panic. Passing a zero-value mlx_array for the
	// absent bias assumes mlx-c treats an empty array as "no bias".
	var bias C.mlx_array
	if b != nil {
		bias = b.(*Array).a
	}
	var r C.mlx_array
	C.mlx_fast_layer_norm(
		&r,
		a.a,
		w.(*Array).a,
		bias,
		C.float(eps),
		ctx.(*Context).stream,
	)
return &Array{a: r}
}
// Pad implements ml.Tensor.
func (a *Array) Pad(ctx ml.Context, shape ...int64) ml.Tensor {
panic("unimplemented")
}
// Permute implements ml.Tensor.
func (a *Array) Permute(ctx ml.Context, shape ...int) ml.Tensor {
panic("unimplemented")
}
// RMSNorm implements ml.Tensor.
func (a *Array) RMSNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
	// mlx_fast_rms_norm takes no bias parameter, so b is accepted here but
	// not applied.
	var r C.mlx_array
C.mlx_fast_rms_norm(
&r,
a.a,
w.(*Array).a,
C.float(eps),
ctx.(*Context).stream,
)
return &Array{a: r}
}
// Reshape implements ml.Tensor.
func (a *Array) Reshape(ctx ml.Context, shape ...int64) ml.Tensor {
cshape := make([]C.int, len(shape))
for i, dim := range shape {
cshape[i] = C.int(dim)
}
var r C.mlx_array
C.mlx_reshape(&r, a.a, (*C.int)(&cshape[0]), C.size_t(len(cshape)), ctx.(*Context).stream)
return &Array{a: r}
}
// Rope implements ml.Tensor.
func (a *Array) Rope(ctx ml.Context, positionIDs ml.Tensor, ropeFactors ml.Tensor, dim uint32, base float32, scale float32) ml.Tensor {
panic("unimplemented")
}
// Rows implements ml.Tensor.
func (a *Array) Rows(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
var r C.mlx_array
slog.Info("rows", "a", a, "a2", a2)
C.mlx_take(&r, a.a, a2.(*Array).a, 0, ctx.(*Context).stream)
return &Array{a: r}
}
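Rows lowers ggml's get_rows onto mlx_take along axis 0: a2 carries row indices (typically token ids) and the result gathers whole rows, i.e. an embedding lookup.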
// SILU implements ml.Tensor.
func (a *Array) SILU(ctx ml.Context) ml.Tensor {
panic("unimplemented")
}
// Scale implements ml.Tensor.
func (a *Array) Scale(ctx ml.Context, s float64) ml.Tensor {
panic("unimplemented")
}
// Shape implements ml.Tensor.
func (a *Array) Shape() []int64 {
shape := make([]int64, C.mlx_array_ndim(a.a))
for i := range shape {
shape[i] = int64(C.mlx_array_dim(a.a, C.int(i)))
}
return shape
}
// Softmax implements ml.Tensor.
func (a *Array) Softmax(ctx ml.Context) ml.Tensor {
panic("unimplemented")
}
// Stack implements ml.Tensor.
func (a *Array) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
panic("unimplemented")
}
// Stride implements ml.Tensor.
func (a *Array) Stride(n int) int64 {
panic("unimplemented")
}
// Tanh implements ml.Tensor.
func (a *Array) Tanh(ctx ml.Context) ml.Tensor {
panic("unimplemented")
}
// Unpad implements ml.Tensor.
func (a *Array) Unpad(ctx ml.Context, shape ...int64) ml.Tensor {
panic("unimplemented")
}
// View implements ml.Tensor.
func (a *Array) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
panic("unimplemented")
}

View File

@@ -10,12 +10,7 @@ type LayerNorm struct {
}
func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
-	t = t.Norm(ctx, eps).Mul(ctx, m.Weight)
-	if m.Bias != nil {
-		t = t.Add(ctx, m.Bias)
-	}
-	return t
+	return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
}
type RMSNorm struct {
@@ -24,10 +19,5 @@ type RMSNorm struct {
}
func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
-	t = t.RMSNorm(ctx, eps).Mul(ctx, m.Weight)
-	if m.Bias != nil {
-		t = t.Add(ctx, m.Bias)
-	}
-	return t
+	return t.RMSNorm(ctx, m.Weight, m.Bias, eps)
}

View File

@@ -1,6 +1,7 @@
package llama
import (
"log/slog"
"math"
"github.com/ollama/ollama/ml"
@@ -52,16 +53,16 @@ type SelfAttention struct {
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
-	batchSize := hiddenState.Dim(1)
+	batchSize := hiddenState.Dim(0)
headDim := opts.hiddenSize / opts.numHeads
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.Rope(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	// q = q.Rope(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.Rope(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	// k = k.Rope(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -72,6 +73,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
slog.Info("self attention", "q", q, "k", k, "v", v)
kq := k.Mulmat(ctx, q)
kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
kq = kq.Softmax(ctx)
@@ -127,6 +130,8 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
slog.Info("breakpoint", "inputs", inputs, "positions", positions, "hiddenState", hiddenState)
for i, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options)
}

View File

@@ -0,0 +1,50 @@
package mllama
import (
"errors"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/model"
)
func BenchmarkProcessText(b *testing.B) {
ours, err := model.New(filepath.Join("testdata", "model.bin"))
if errors.Is(err, os.ErrNotExist) {
b.Skip("no model.bin")
} else if err != nil {
b.Fatal(err)
}
var ids []int32
b.Run("encode", func(b *testing.B) {
txt, err := os.ReadFile(filepath.Join("..", "testdata", "war-and-peace.txt"))
if err != nil {
b.Fatal(err)
}
for i := 0; i < b.N; i++ {
b.StopTimer()
b.StartTimer()
ids, err = ours.(model.TextProcessor).Encode(string(txt))
if err != nil {
b.Fatal(err)
}
}
})
b.Run("decode", func(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
b.StartTimer()
_, err = ours.(model.TextProcessor).Decode(ids)
if err != nil {
b.Fatal(err)
}
}
})
}

model/testdata/war-and-peace.txt vendored Normal file (63845 lines)

File diff suppressed because it is too large