diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index e7e47c964..27e229fcf 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -237,5 +237,5 @@ jobs:
       - uses: actions/checkout@v4
       - name: Verify patches apply cleanly and do not change files
         run: |
-          make -f Makefile.sync clean sync
-          git diff --compact-summary --exit-code
+          make -f Makefile.sync clean checkout apply-patches sync
+          git diff --compact-summary --exit-code
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e2447f32a..4990bc6f1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -51,7 +51,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
 
 set(GGML_CPU ON)
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml)
 set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
 
 get_target_property(CPU_VARIANTS ggml-cpu MANUALLY_ADDED_DEPENDENCIES)
diff --git a/Makefile.sync b/Makefile.sync
index 2bb8d4ab8..4a6a1039d 100644
--- a/Makefile.sync
+++ b/Makefile.sync
@@ -1,6 +1,6 @@
 UPSTREAM=https://github.com/ggerganov/llama.cpp.git
 WORKDIR=llama/vendor
-FETCH_HEAD=d7cfe1ffe0f435d0048a6058d529daf76e072d9c
+FETCH_HEAD=71e90e8813f90097701e62f7fce137d96ddf41e2
 
 .PHONY: help
 help:
@@ -15,18 +15,18 @@ help:
 	@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
 
 .PHONY: sync
-sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml apply-patches
+sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml
 
 .PHONY: llama/build-info.cpp
 llama/build-info.cpp: llama/build-info.cpp.in
 	sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
 
 .PHONY: llama/llama.cpp
-llama/llama.cpp: llama/vendor/ apply-patches
+llama/llama.cpp: llama/vendor/
 	rsync -arvzc -f "merge $@/.rsync-filter" $< $@
 
-.PHONY: ml/backend/ggml/ggml apply-patches
-ml/backend/ggml/ggml: llama/vendor/ggml/ apply-patches
+.PHONY: ml/backend/ggml/ggml
+ml/backend/ggml/ggml: llama/vendor/ggml/
 	rsync -arvzc -f "merge $@/.rsync-filter" $< $@
 
 PATCHES=$(wildcard llama/patches/*.patch)
diff --git a/llama/build-info.cpp b/llama/build-info.cpp
index 58c1b0801..06deb0dde 100644
--- a/llama/build-info.cpp
+++ b/llama/build-info.cpp
@@ -1,4 +1,4 @@
 int LLAMA_BUILD_NUMBER = 0;
-char const *LLAMA_COMMIT = "d7cfe1ffe0f435d0048a6058d529daf76e072d9c";
+char const *LLAMA_COMMIT = "71e90e8813f90097701e62f7fce137d96ddf41e2";
 char const *LLAMA_COMPILER = "";
 char const *LLAMA_BUILD_TARGET = "";
diff --git a/llama/llama.cpp/.rsync-filter b/llama/llama.cpp/.rsync-filter
index 186e1c12e..9bc06684b 100644
--- a/llama/llama.cpp/.rsync-filter
+++ b/llama/llama.cpp/.rsync-filter
@@ -13,6 +13,7 @@ include include/llama-*.*
 include examples/
 include examples/llava/
 include examples/llava/clip.*
+include examples/llava/clip-impl.*
 include examples/llava/llava.*
 include src/
 include src/llama.*
diff --git a/llama/llama.cpp/common/common.cpp b/llama/llama.cpp/common/common.cpp
index d2b0d50e3..94f545f81 100644
--- a/llama/llama.cpp/common/common.cpp
+++ b/llama/llama.cpp/common/common.cpp
@@ -7,10 +7,6 @@
 #include "common.h"
 #include "log.h"
-// Change JSON_ASSERT from assert() to GGML_ASSERT:
-#define JSON_ASSERT GGML_ASSERT
-#include "json.hpp"
-#include "json-schema-to-grammar.h"
 #include "llama.h"
 
 #include
@@ -52,47 +48,11 @@
 #include
 #include
 #endif
-#if defined(LLAMA_USE_CURL)
-#include
-#include
-#include
-#endif
 
 #if defined(_MSC_VER)
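The common.cpp hunks that begin here drop the vendored copy's nlohmann/json and libcurl dependencies, so model loading in this tree goes straight to a local GGUF path. A minimal sketch of that flow, assuming the common_params_model fields (params.model.path) introduced later in this diff:

// Sketch only: local-path loading as wired up in the common_init_from_params hunk below.
// Assumes the new common_params_model struct from this diff; no hf_repo / model_url fallback.
#include "common.h"
#include "llama.h"

static llama_model * load_local_model(common_params & params) {
    llama_model_params mparams = common_model_params_to_llama(params);
    // The vendored build expects the GGUF file to already exist on disk.
    return llama_model_load_from_file(params.model.path.c_str(), mparams);
}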
#pragma warning(disable: 4244 4267) // possible loss of data #endif -#if defined(LLAMA_USE_CURL) -#ifdef __linux__ -#include -#elif defined(_WIN32) -# if !defined(PATH_MAX) -# define PATH_MAX MAX_PATH -# endif -#else -#include -#endif -#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 - -// -// CURL utils -// - -using curl_ptr = std::unique_ptr; - -// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one -struct curl_slist_ptr { - struct curl_slist * ptr = nullptr; - ~curl_slist_ptr() { - if (ptr) { - curl_slist_free_all(ptr); - } - } -}; -#endif // LLAMA_USE_CURL - -using json = nlohmann::ordered_json; - // // CPU utils // @@ -483,6 +443,11 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } +std::string regex_escape(const std::string & s) { + static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); + return std::regex_replace(s, special_chars, "\\$0"); +} + std::string string_join(const std::vector & values, const std::string & separator) { std::ostringstream result; for (size_t i = 0; i < values.size(); ++i) { @@ -865,7 +830,7 @@ std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#ifdef __linux__ +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); } else { @@ -875,7 +840,9 @@ std::string fs_get_cache_directory() { cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); #elif defined(_WIN32) cache_directory = std::getenv("LOCALAPPDATA"); -#endif // __linux__ +#else +# error Unknown architecture +#endif cache_directory = ensure_trailing_slash(cache_directory); cache_directory += "llama.cpp"; } @@ -896,22 +863,14 @@ std::string fs_get_cache_file(const std::string & filename) { // // Model utils // + struct common_init_result common_init_from_params(common_params & params) { common_init_result iparams; auto mparams = common_model_params_to_llama(params); - llama_model * model = nullptr; - - if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams); - } else if (!params.model_url.empty()) { - model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams); - } else { - model = llama_model_load_from_file(params.model.c_str(), mparams); - } - + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); if (model == NULL) { - LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str()); + LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); return iparams; } @@ -946,13 +905,13 @@ struct common_init_result common_init_from_params(common_params & params) { llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { - LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); + LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); llama_model_free(model); return iparams; } - if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) { - LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__); + if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) { + LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache 
shifting\n", __func__); params.ctx_shift = false; } @@ -1029,6 +988,8 @@ struct common_init_result common_init_from_params(common_params & params) { if (params.warmup) { LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); + llama_set_warmup(lctx, true); + std::vector tmp; llama_token bos = llama_vocab_bos(vocab); llama_token eos = llama_vocab_eos(vocab); @@ -1056,9 +1017,10 @@ struct common_init_result common_init_from_params(common_params & params) { if (llama_model_has_decoder(model)) { llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); } - llama_kv_cache_clear(lctx); + llama_kv_self_clear(lctx); llama_synchronize(lctx); llama_perf_context_reset(lctx); + llama_set_warmup(lctx, false); } iparams.model.reset(model); @@ -1067,6 +1029,19 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } +std::string get_model_endpoint() { + const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); + // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. + const char * hf_endpoint_env = getenv("HF_ENDPOINT"); + const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env; + std::string model_endpoint = "https://huggingface.co/"; + if (endpoint_env) { + model_endpoint = endpoint_env; + if (model_endpoint.back() != '/') model_endpoint += '/'; + } + return model_endpoint; +} + void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { llama_clear_adapter_lora(ctx); for (auto & la : lora) { @@ -1082,15 +1057,18 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { if (!params.devices.empty()) { mparams.devices = params.devices.data(); } + if (params.n_gpu_layers != -1) { mparams.n_gpu_layers = params.n_gpu_layers; } + mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; + if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { @@ -1098,6 +1076,13 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.kv_overrides = params.kv_overrides.data(); } + if (params.tensor_buft_overrides.empty()) { + mparams.tensor_buft_overrides = NULL; + } else { + GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern"); + mparams.tensor_buft_overrides = params.tensor_buft_overrides.data(); + } + return mparams; } @@ -1157,451 +1142,6 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p return tpp; } -#ifdef LLAMA_USE_CURL - -#define CURL_MAX_RETRY 3 -#define CURL_RETRY_DELAY_SECONDS 2 - -static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) { - int remaining_attempts = max_attempts; - - while (remaining_attempts > 0) { - LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); - - CURLcode res = curl_easy_perform(curl); - if (res == CURLE_OK) { - return true; - } - - int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000; - LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, 
curl_easy_strerror(res), exponential_backoff_delay); - - remaining_attempts--; - std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); - } - - LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts); - - return false; -} - -static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { - // Initialize libcurl - curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); - curl_slist_ptr http_headers; - if (!curl) { - LOG_ERR("%s: error initializing libcurl\n", __func__); - return false; - } - - bool force_download = false; - - // Set the URL, allow to follow http redirection - curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); - - // Check if hf-token or bearer-token was specified - if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer " + hf_token; - http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); - } - -#if defined(_WIN32) - // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of - // operating system. Currently implemented under MS-Windows. - curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - - // Check if the file already exists locally - auto file_exists = std::filesystem::exists(path); - - // If the file exists, check its JSON metadata companion file. - std::string metadata_path = path + ".json"; - nlohmann::json metadata; - std::string etag; - std::string last_modified; - - if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). - std::ifstream metadata_in(metadata_path); - if (metadata_in.good()) { - try { - metadata_in >> metadata; - LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); - if (metadata.contains("url") && metadata.at("url").is_string()) { - auto previous_url = metadata.at("url").get(); - if (previous_url != url) { - LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); - return false; - } - } - if (metadata.contains("etag") && metadata.at("etag").is_string()) { - etag = metadata.at("etag"); - } - if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { - last_modified = metadata.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); - return false; - } - } - } else { - LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); - } - - // Send a HEAD request to retrieve the etag and last-modified headers - struct common_load_model_from_url_headers { - std::string etag; - std::string last_modified; - }; - - common_load_model_from_url_headers headers; - - { - typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); - auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { - common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; - - static std::regex header_regex("([^:]+): (.*)\r\n"); - static std::regex etag_regex("ETag", std::regex_constants::icase); - static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - - std::string header(buffer, n_items); - std::smatch match; - if 
(std::regex_match(header, match, header_regex)) { - const std::string & key = match[1]; - const std::string & value = match[2]; - if (std::regex_match(key, match, etag_regex)) { - headers->etag = value; - } else if (std::regex_match(key, match, last_modified_regex)) { - headers->last_modified = value; - } - } - return n_items; - }; - - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress - curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); - curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); - if (!was_perform_successful) { - return false; - } - - long http_code = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code != 200) { - // HEAD not supported, we don't know if the file has changed - // force trigger downloading - force_download = true; - LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); - } - } - - bool should_download = !file_exists || force_download; - if (!should_download) { - if (!etag.empty() && etag != headers.etag) { - LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); - should_download = true; - } else if (!last_modified.empty() && last_modified != headers.last_modified) { - LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); - should_download = true; - } - } - if (should_download) { - std::string path_temporary = path + ".downloadInProgress"; - if (file_exists) { - LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); - if (remove(path.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return false; - } - } - - // Set the output file - - struct FILE_deleter { - void operator()(FILE * f) const { - fclose(f); - } - }; - - std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); - if (!outfile) { - LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); - return false; - } - - typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd); - auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { - return fwrite(data, size, nmemb, (FILE *)fd); - }; - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L); - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); - - // display download progress - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L); - - // helper function to hide password in URL - auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string { - std::size_t protocol_pos = url.find("://"); - if (protocol_pos == std::string::npos) { - return url; // Malformed URL - } - - std::size_t at_pos = url.find('@', protocol_pos + 3); - if (at_pos == std::string::npos) { - return url; // No password in URL - } - - return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos); - }; - - // start the download - LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, - llama_download_hide_password_in_url(url).c_str(), path.c_str(), 
headers.etag.c_str(), headers.last_modified.c_str()); - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); - if (!was_perform_successful) { - return false; - } - - long http_code = 0; - curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code < 200 || http_code >= 400) { - LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); - return false; - } - - // Causes file to be closed explicitly here before we rename it. - outfile.reset(); - - // Write the updated JSON metadata file. - metadata.update({ - {"url", url}, - {"etag", headers.etag}, - {"lastModified", headers.last_modified} - }); - std::ofstream(metadata_path) << metadata.dump(4); - LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); - - if (rename(path_temporary.c_str(), path.c_str()) != 0) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return false; - } - } - - return true; -} - -struct llama_model * common_load_model_from_url( - const std::string & model_url, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params) { - // Basic validation of the model_url - if (model_url.empty()) { - LOG_ERR("%s: invalid model_url\n", __func__); - return NULL; - } - - if (!common_download_file(model_url, local_path, hf_token)) { - return NULL; - } - - // check for additional GGUFs split to download - int n_split = 0; - { - struct gguf_init_params gguf_params = { - /*.no_alloc = */ true, - /*.ctx = */ NULL, - }; - auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params); - if (!ctx_gguf) { - LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str()); - return NULL; - } - - auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT); - if (key_n_split >= 0) { - n_split = gguf_get_val_u16(ctx_gguf, key_n_split); - } - - gguf_free(ctx_gguf); - } - - if (n_split > 1) { - char split_prefix[PATH_MAX] = {0}; - char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0}; - - // Verify the first split file format - // and extract split URL and PATH prefixes - { - if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split); - return NULL; - } - - if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split); - return NULL; - } - } - - // Prepare download in parallel - std::vector> futures_download; - for (int idx = 1; idx < n_split; idx++) { - futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool { - char split_path[PATH_MAX] = {0}; - llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split); - - char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; - llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); - - return common_download_file(split_url, split_path, hf_token); - }, idx)); - } - - // Wait for all downloads to complete - for (auto & f : futures_download) { - if (!f.get()) { - return NULL; - } - } - } - - return llama_model_load_from_file(local_path.c_str(), params); -} - -struct llama_model * common_load_model_from_hf( - const std::string & repo, - const std::string & 
remote_path, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params) { - // construct hugging face model url: - // - // --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf - // https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf - // - // --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf - // https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf - // - - std::string model_url = "https://huggingface.co/"; - model_url += repo; - model_url += "/resolve/main/"; - model_url += remote_path; - - return common_load_model_from_url(model_url, local_path, hf_token, params); -} - -/** - * Allow getting the HF file from the HF repo with tag (like ollama), for example: - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 - * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s - * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) - * - * Return pair of (with "repo" already having tag removed) - * - * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. - */ -std::pair common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) { - auto parts = string_split(hf_repo_with_tag, ':'); - std::string tag = parts.size() > 1 ? parts.back() : "latest"; - std::string hf_repo = parts[0]; - if (string_split(hf_repo, '/').size() != 2) { - throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); - } - - // fetch model info from Hugging Face Hub API - json model_info; - curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); - curl_slist_ptr http_headers; - std::string res_str; - std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; - curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); - typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); - auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { - static_cast(data)->append((char * ) ptr, size * nmemb); - return size * nmemb; - }; - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); -#if defined(_WIN32) - curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer " + hf_token; - http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); - } - // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response - http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); - http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); - - CURLcode res = curl_easy_perform(curl.get()); - - if (res != CURLE_OK) { - throw std::runtime_error("error: cannot make GET request to HF API"); - } - - long res_code; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); - if (res_code == 200) { - model_info = json::parse(res_str); - } else if (res_code == 401) { - 
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); - } else { - throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); - } - - // check response - if (!model_info.contains("ggufFile")) { - throw std::runtime_error("error: model does not have ggufFile"); - } - json & gguf_file = model_info.at("ggufFile"); - if (!gguf_file.contains("rfilename")) { - throw std::runtime_error("error: ggufFile does not have rfilename"); - } - - return std::make_pair(hf_repo, gguf_file.at("rfilename")); -} - -#else - -struct llama_model * common_load_model_from_url( - const std::string & /*model_url*/, - const std::string & /*local_path*/, - const std::string & /*hf_token*/, - const struct llama_model_params & /*params*/) { - LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); - return nullptr; -} - -struct llama_model * common_load_model_from_hf( - const std::string & /*repo*/, - const std::string & /*remote_path*/, - const std::string & /*local_path*/, - const std::string & /*hf_token*/, - const struct llama_model_params & /*params*/) { - LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); - return nullptr; -} - -std::pair common_get_hf_file(const std::string &, const std::string &) { - LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); - return std::make_pair("", ""); -} - -#endif // LLAMA_USE_CURL - // // Batch utils // @@ -2025,4 +1565,3 @@ common_control_vector_data common_control_vector_load(const std::vector grammar_trigger_words; // optional trigger words to trigger lazy grammar - std::vector grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens. 
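The common.h hunk continuing below replaces the separate grammar_trigger_words / grammar_trigger_tokens vectors with a single grammar_triggers container. A hedged sketch of how a caller might populate it; the field and enum names (type/value/token, COMMON_GRAMMAR_TRIGGER_TYPE_*) follow the sampling.cpp hunk later in this diff, and the token id argument is hypothetical:

// Sketch: filling the new trigger container that replaces the old word/token vectors.
#include "common.h"
#include "llama.h"

static void add_lazy_triggers(common_params_sampling & sparams, llama_token tool_call_tok) {
    sparams.grammar_lazy = true;

    common_grammar_trigger word;
    word.type  = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
    word.value = "<tool_call>";      // literal text; the sampler regex-escapes it
    sparams.grammar_triggers.push_back(word);

    common_grammar_trigger tok;
    tok.type  = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
    tok.token = tool_call_tok;       // special token id that should enable the grammar
    sparams.grammar_triggers.push_back(tok);
}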
+ std::vector grammar_triggers; // optional triggers (for lazy grammars) std::set preserved_tokens; std::vector logit_bias; // logit biases to apply @@ -173,6 +180,13 @@ struct common_params_sampling { std::string print() const; }; +struct common_params_model { + std::string path = ""; // model local path // NOLINT + std::string url = ""; // model url to download // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT +}; + struct common_params_speculative { std::vector devices; // devices to use for offloading @@ -186,19 +200,13 @@ struct common_params_speculative { struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT - - std::string model = ""; // draft model for speculative decoding // NOLINT - std::string model_url = ""; // model url to download // NOLINT + struct common_params_model model; }; struct common_params_vocoder { - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT + struct common_params_model model; - std::string model = ""; // model path // NOLINT - std::string model_url = ""; // model url to download // NOLINT + std::string speaker_file = ""; // speaker file path // NOLINT bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT }; @@ -254,13 +262,12 @@ struct common_params { struct common_params_speculative speculative; struct common_params_vocoder vocoder; - std::string model = ""; // model path // NOLINT + struct common_params_model model; + std::string model_alias = ""; // model alias // NOLINT - std::string model_url = ""; // model url to download // NOLINT std::string hf_token = ""; // HF token // NOLINT - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT std::string prompt = ""; // NOLINT + std::string system_prompt = ""; // NOLINT std::string prompt_file = ""; // store the external prompt file name // NOLINT std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT std::string input_prefix = ""; // string to prefix user inputs with // NOLINT @@ -272,6 +279,7 @@ struct common_params { std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. 
reverse prompts) std::vector kv_overrides; + std::vector tensor_buft_overrides; bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) std::vector lora_adapters; // lora adapter path with user defined scale @@ -325,13 +333,15 @@ struct common_params { bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data + bool single_turn = false; // single turn chat conversation + ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; // multimodal models (see examples/llava) - std::string mmproj = ""; // path to multimodal projector // NOLINT + struct common_params_model mmproj; std::vector image; // path to image file(s) // embedding @@ -391,8 +401,6 @@ struct common_params { int32_t i_pos = -1; // position of the passkey in the junk text // imatrix params - std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file - int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations int32_t i_chunk = 0; // start processing from this chunk @@ -404,16 +412,16 @@ struct common_params { int n_pca_batch = 100; int n_pca_iterations = 1000; dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; - std::string cvector_outfile = "control_vector.gguf"; std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; bool spm_infill = false; // suffix/prefix/middle pattern for infill - std::string lora_outfile = "ggml-lora-merged-f16.gguf"; - // batched-bench params bool batched_bench_output_jsonl = false; + + // common params + std::string out_file; // output filename for all example programs }; // call once at the start of a program if it uses libcommon @@ -453,6 +461,8 @@ std::string string_repeat(const std::string & str, size_t n); void string_replace_all(std::string & s, const std::string & search, const std::string & replace); +std::string regex_escape(const std::string & s); + template static std::vector string_split(const std::string & str, char delim) { static_assert(!std::is_same::value, "Please use the specialized version for std::string"); @@ -530,26 +540,11 @@ struct llama_model_params common_model_params_to_llama ( common_params struct llama_context_params common_context_params_to_llama(const common_params & params); struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params); -struct llama_model * common_load_model_from_url( - const std::string & model_url, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params); - -struct llama_model * common_load_model_from_hf( - const std::string & repo, - const std::string & remote_path, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params); - -std::pair common_get_hf_file( - const std::string & hf_repo_with_tag, - const std::string & hf_token); - // clear LoRA adapters from context, then apply new list of adapters void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); +std::string get_model_endpoint(); + // // Batch utils // diff --git a/llama/llama.cpp/common/json-schema-to-grammar.cpp 
b/llama/llama.cpp/common/json-schema-to-grammar.cpp index 30c28808f..56043678c 100644 --- a/llama/llama.cpp/common/json-schema-to-grammar.cpp +++ b/llama/llama.cpp/common/json-schema-to-grammar.cpp @@ -264,7 +264,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & throw std::runtime_error("At least one of min_value or max_value must be set"); } -const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}"; +const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}"; struct BuiltinRule { std::string content; @@ -764,11 +764,10 @@ private: public: SchemaConverter( const std::function & fetch_json, - bool dotall, - bool compact_spaces) + bool dotall) : _fetch_json(fetch_json), _dotall(dotall) { - _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE; + _rules["space"] = SPACE_RULE; } void resolve_refs(json & schema, const std::string & url) { @@ -1007,7 +1006,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) { } std::string build_grammar(const std::function & cb, const common_grammar_options & options) { - SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces); + SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall); common_grammar_builder builder { /* .add_rule = */ [&](const std::string & name, const std::string & rule) { return converter._add_rule(name, rule); diff --git a/llama/llama.cpp/common/json-schema-to-grammar.h b/llama/llama.cpp/common/json-schema-to-grammar.h index 62a3b0a44..4613f5d9f 100644 --- a/llama/llama.cpp/common/json-schema-to-grammar.h +++ b/llama/llama.cpp/common/json-schema-to-grammar.h @@ -16,7 +16,6 @@ struct common_grammar_builder { struct common_grammar_options { bool dotall = false; - bool compact_spaces = false; }; std::string build_grammar(const std::function & cb, const common_grammar_options & options = {}); diff --git a/llama/llama.cpp/common/sampling.cpp b/llama/llama.cpp/common/sampling.cpp index 37a0d9c85..1735b6501 100644 --- a/llama/llama.cpp/common/sampling.cpp +++ b/llama/llama.cpp/common/sampling.cpp @@ -4,6 +4,7 @@ #include #include +#include // the ring buffer works similarly to std::deque, but with a fixed capacity // TODO: deduplicate with llama-impl.h @@ -159,17 +160,57 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE } else { - std::vector trigger_words; - trigger_words.reserve(params.grammar_trigger_words.size()); - for (const auto & str : params.grammar_trigger_words) { - trigger_words.push_back(str.word.c_str()); + std::vector patterns_at_start; + std::vector patterns_anywhere; + std::vector trigger_tokens; + for (const auto & trigger : params.grammar_triggers) { + switch (trigger.type) { + case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: + { + const auto & word = trigger.value; + patterns_anywhere.push_back(regex_escape(word)); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: + { + const auto & pattern = trigger.value; + (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? 
patterns_at_start : patterns_anywhere).push_back(pattern); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: + { + const auto token = trigger.token; + trigger_tokens.push_back(token); + break; + } + default: + GGML_ASSERT(false && "unknown trigger type"); + } + } + + std::vector trigger_patterns; + if (!patterns_at_start.empty()) { + trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*"); + } + if (!patterns_anywhere.empty()) { + trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); + } + + std::vector trigger_patterns_c; + trigger_patterns_c.reserve(trigger_patterns.size()); + for (const auto & regex : trigger_patterns) { + trigger_patterns_c.push_back(regex.c_str()); } grmr = params.grammar_lazy - ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", - trigger_words.data(), trigger_words.size(), - params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) + ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", + trigger_patterns_c.data(), trigger_patterns_c.size(), + trigger_tokens.data(), trigger_tokens.size()) : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); + if (!grmr) { + return nullptr; + } } auto * result = new common_sampler { diff --git a/llama/llama.cpp/examples/llava/clip-impl.h b/llama/llama.cpp/examples/llava/clip-impl.h new file mode 100644 index 000000000..4d7340a56 --- /dev/null +++ b/llama/llama.cpp/examples/llava/clip-impl.h @@ -0,0 +1,344 @@ +#include "ggml.h" +#include "gguf.h" +#include "clip.h" + +#include "clip.h" + +#include +#include +#include +#include +#include +#include +#include + +// Internal header for clip.cpp + +#define KEY_FTYPE "general.file_type" +#define KEY_NAME "general.name" +#define KEY_DESCRIPTION "general.description" +#define KEY_HAS_TEXT_ENC "clip.has_text_encoder" +#define KEY_HAS_VIS_ENC "clip.has_vision_encoder" +#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" +#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" +#define KEY_HAS_GLM_PROJ "clip.has_glm_projector" +#define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" +#define KEY_USE_GELU "clip.use_gelu" +#define KEY_USE_SILU "clip.use_silu" +#define KEY_N_EMBD "clip.%s.embedding_length" +#define KEY_N_FF "clip.%s.feed_forward_length" +#define KEY_N_BLOCK "clip.%s.block_count" +#define KEY_N_HEAD "clip.%s.attention.head_count" +#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" +#define KEY_PROJ_DIM "clip.%s.projection_dim" +#define KEY_TOKENS "tokenizer.ggml.tokens" +#define KEY_N_POSITIONS "clip.text.context_length" +#define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_PATCH_SIZE "clip.vision.patch_size" +#define KEY_IMAGE_MEAN "clip.vision.image_mean" +#define KEY_IMAGE_STD "clip.vision.image_std" +#define KEY_PROJ_TYPE "clip.projector_type" +#define KEY_FEATURE_LAYER "clip.vision.feature_layer" + +#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" +#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" +#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" + + +// +// tensor name constants +// + +#define TN_TOKEN_EMBD "%s.token_embd.weight" +#define TN_POS_EMBD "%s.position_embd.weight" +#define TN_CLASS_EMBD "v.class_embd" +#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat +#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" +#define 
TN_PATCH_BIAS "v.patch_embd.bias" +#define TN_ATTN_K "%s.blk.%d.attn_k.%s" +#define TN_ATTN_Q "%s.blk.%d.attn_q.%s" +#define TN_ATTN_V "%s.blk.%d.attn_v.%s" +#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" +#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" +#define TN_FFN_UP "%s.blk.%d.ffn_up.%s" +#define TN_LN_1 "%s.blk.%d.ln1.%s" +#define TN_LN_2 "%s.blk.%d.ln2.%s" +#define TN_LN_PRE "%s.pre_ln.%s" +#define TN_LN_POST "%s.post_ln.%s" +#define TN_TEXT_PROJ "text_projection.weight" +#define TN_VIS_PROJ "visual_projection.weight" +#define TN_LLAVA_PROJ "mm.%d.%s" +#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s" +#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" +#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" +#define TN_IMAGE_NEWLINE "model.image_newline" +#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 +#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3 + +// mimicpmv +#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k" +#define TN_MINICPMV_QUERY "resampler.query" +#define TN_MINICPMV_PROJ "resampler.proj.weight" +#define TN_MINICPMV_KV_PROJ "resampler.kv.weight" +#define TN_MINICPMV_ATTN "resampler.attn.%s.%s" +#define TN_MINICPMV_LN "resampler.ln_%s.%s" + +#define TN_GLM_ADAPER_CONV "adapter.conv.%s" +#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s" +#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s" +#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s" +#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" +#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" +#define TN_GLM_BOI_W "adapter.boi" +#define TN_GLM_EOI_W "adapter.eoi" + +enum projector_type { + PROJECTOR_TYPE_MLP, + PROJECTOR_TYPE_MLP_NORM, + PROJECTOR_TYPE_LDP, + PROJECTOR_TYPE_LDPV2, + PROJECTOR_TYPE_RESAMPLER, + PROJECTOR_TYPE_GLM_EDGE, + PROJECTOR_TYPE_MERGER, + PROJECTOR_TYPE_GEMMA3, + PROJECTOR_TYPE_UNKNOWN, +}; + +static std::map PROJECTOR_TYPE_NAMES = { + { PROJECTOR_TYPE_MLP, "mlp" }, + { PROJECTOR_TYPE_LDP, "ldp" }, + { PROJECTOR_TYPE_LDPV2, "ldpv2"}, + { PROJECTOR_TYPE_RESAMPLER, "resampler"}, + { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, + { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, + { PROJECTOR_TYPE_GEMMA3, "gemma3"}, +}; + +static projector_type clip_projector_type_from_string(const std::string & str) { + for (const auto & pair : PROJECTOR_TYPE_NAMES) { + if (pair.second == str) { + return pair.first; + } + } + return PROJECTOR_TYPE_UNKNOWN; +} + +// RGB uint8 image +struct clip_image_u8 { + int nx; + int ny; + + std::vector buf; +}; + +// RGB float32 image (NHWC) +// Memory layout: RGBRGBRGB... 
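clip-impl.h also centralizes the projector-type table and its reverse lookup shown above. A small sketch of resolving the enum from gguf metadata, using only gguf.h calls shown elsewhere in this diff (gguf_find_key, gguf_get_val_str) and the KEY_PROJ_TYPE constant; the MLP fallback is an assumption, not necessarily clip.cpp's exact behavior:

// Sketch: map the gguf "clip.projector_type" string onto the projector_type enum.
#include "clip-impl.h"
#include "gguf.h"

static projector_type read_projector_type(const gguf_context * ctx) {
    const auto idx = gguf_find_key(ctx, KEY_PROJ_TYPE);   // "clip.projector_type"
    if (idx < 0) {
        return PROJECTOR_TYPE_MLP;                        // assumed default when the key is absent
    }
    return clip_projector_type_from_string(gguf_get_val_str(ctx, idx));
}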
+struct clip_image_f32 { + int nx; + int ny; + + std::vector buf; +}; + +// +// logging +// + +static void clip_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +struct clip_logger_state { + ggml_log_level verbosity_thold; + ggml_log_callback log_callback; + void * log_callback_user_data; +}; + +extern struct clip_logger_state g_logger_state; + +static void clip_log_internal_v(enum ggml_log_level level, const char * format, va_list args) { + if (format == NULL) { + return; + } + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data); + } else { + char * buffer2 = (char *) calloc(len + 1, sizeof(char)); + vsnprintf(buffer2, len + 1, format, args_copy); + buffer2[len] = 0; + g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data); + free(buffer2); + } + va_end(args_copy); +} + +static void clip_log_internal(enum ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + clip_log_internal_v(level, format, args); + va_end(args); +} + +#define LOG_TMPL(level, ...) \ + do { \ + if ((level) >= g_logger_state.verbosity_thold) { \ + clip_log_internal((level), __VA_ARGS__); \ + } \ + } while (0) +#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, __VA_ARGS__) +#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, __VA_ARGS__) +#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__) + +// +// cpp wrappers +// + +// wrapper for clip_image_size +struct clip_image_size_deleter { + void operator()(clip_image_size * val) { clip_image_size_free(val); } +}; +typedef std::unique_ptr clip_image_size_ptr; + +// wrapper for clip_image_u8 +struct clip_image_u8_deleter { + void operator()(clip_image_u8 * val) { clip_image_u8_free(val); } +}; +typedef std::unique_ptr clip_image_u8_ptr; + +// wrapper for clip_image_f32 +struct clip_image_f32_deleter { + void operator()(clip_image_f32 * val) { clip_image_f32_free(val); } +}; +typedef std::unique_ptr clip_image_f32_ptr; + +struct clip_image_u8_batch { + std::vector entries; +}; + +struct clip_image_f32_batch { + std::vector entries; +}; + +// +// common utils +// + +static std::string string_format(const char * fmt, ...) 
{ + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), buf.size()); +} + +static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { + if (search.empty()) { + return; + } + std::string builder; + builder.reserve(s.length()); + size_t pos = 0; + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); +} + +// split string by a `std::string delim` instead of `char delim` +static std::vector string_split_str(std::string s, const std::string & delimiter) { + std::vector tokens; + size_t pos = 0; + std::string token; + while ((pos = s.find(delimiter)) != std::string::npos) { + token = s.substr(0, pos); + tokens.push_back(token); + s.erase(0, pos + delimiter.length()); + } + tokens.push_back(s); + return tokens; +} + +// +// gguf utils +// + +static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) { + switch (type) { + case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]); + case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]); + case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]); + case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]); + case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]); + case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]); + case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]); + case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); + case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); + case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); + case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + default: return string_format("unknown type %d", type); + } +} + +static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + + switch (type) { + case GGUF_TYPE_STRING: + return gguf_get_val_str(ctx_gguf, i); + case GGUF_TYPE_ARRAY: + { + const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); + int arr_n = gguf_get_arr_n(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? 
nullptr : gguf_get_arr_data(ctx_gguf, i); + std::stringstream ss; + ss << "["; + for (int j = 0; j < arr_n; j++) { + if (arr_type == GGUF_TYPE_STRING) { + std::string val = gguf_get_arr_str(ctx_gguf, i, j); + // escape quotes + string_replace_all(val, "\\", "\\\\"); + string_replace_all(val, "\"", "\\\""); + ss << '"' << val << '"'; + } else if (arr_type == GGUF_TYPE_ARRAY) { + ss << "???"; + } else { + ss << gguf_data_to_str(arr_type, data, j); + } + if (j < arr_n - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } + default: + return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0); + } +} + +// +// API used internally with mtmd +// + +projector_type clip_get_projector_type(const struct clip_ctx * ctx); diff --git a/llama/llama.cpp/examples/llava/clip.cpp b/llama/llama.cpp/examples/llava/clip.cpp index 54265bebb..4b72ea9fd 100644 --- a/llama/llama.cpp/examples/llava/clip.cpp +++ b/llama/llama.cpp/examples/llava/clip.cpp @@ -3,32 +3,14 @@ // I'll gradually clean and extend it // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch #include "clip.h" +#include "clip-impl.h" #include "ggml.h" +#include "ggml-cpp.h" #include "ggml-cpu.h" #include "ggml-alloc.h" #include "ggml-backend.h" #include "gguf.h" -#ifdef GGML_USE_CUDA -#include "ggml-cuda.h" -#endif - -#ifdef GGML_USE_SYCL -#include "ggml-sycl.h" -#endif - -#ifdef GGML_USE_METAL -#include "ggml-metal.h" -#endif - -#ifdef GGML_USE_CANN -#include "ggml-cann.h" -#endif - -#ifdef GGML_USE_VULKAN -#include "ggml-vulkan.h" -#endif - #define STB_IMAGE_IMPLEMENTATION #include "stb_image.h" @@ -46,18 +28,6 @@ #include #include -#if defined(LLAVA_LOG_OFF) -# define LOG_INF(...) -# define LOG_WRN(...) -# define LOG_ERR(...) -# define LOG_DBG(...) -#else // defined(LLAVA_LOG_OFF) -# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -#endif // defined(LLAVA_LOG_OFF) - #if defined(_WIN32) #define WIN32_LEAN_AND_MEAN #ifndef NOMINMAX @@ -71,268 +41,10 @@ #endif #endif +struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; + //#define CLIP_DEBUG_FUNCTIONS -// RGB uint8 image -struct clip_image_u8 { - int nx; - int ny; - - std::vector buf; -}; - -// RGB float32 image (NHWC) -// Memory layout: RGBRGBRGB... -struct clip_image_f32 { - int nx; - int ny; - - std::vector buf; -}; - -static std::string format(const char * fmt, ...) 
{ - va_list ap; - va_list ap2; - va_start(ap, fmt); - va_copy(ap2, ap); - int size = vsnprintf(NULL, 0, fmt, ap); - GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT - std::vector buf(size + 1); - int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); - GGML_ASSERT(size2 == size); - va_end(ap2); - va_end(ap); - return std::string(buf.data(), buf.size()); -} - -// -// key constants -// - -#define KEY_FTYPE "general.file_type" -#define KEY_NAME "general.name" -#define KEY_DESCRIPTION "general.description" -#define KEY_HAS_TEXT_ENC "clip.has_text_encoder" -#define KEY_HAS_VIS_ENC "clip.has_vision_encoder" -#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" -#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" -#define KEY_HAS_GLM_PROJ "clip.has_glm_projector" -#define KEY_MINICPMV_VERSION "clip.minicpmv_version" -#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" -#define KEY_USE_GELU "clip.use_gelu" -#define KEY_USE_SILU "clip.use_silu" -#define KEY_N_EMBD "clip.%s.embedding_length" -#define KEY_N_FF "clip.%s.feed_forward_length" -#define KEY_N_BLOCK "clip.%s.block_count" -#define KEY_N_HEAD "clip.%s.attention.head_count" -#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" -#define KEY_PROJ_DIM "clip.%s.projection_dim" -#define KEY_TOKENS "tokenizer.ggml.tokens" -#define KEY_N_POSITIONS "clip.text.context_length" -#define KEY_IMAGE_SIZE "clip.vision.image_size" -#define KEY_PATCH_SIZE "clip.vision.patch_size" -#define KEY_IMAGE_MEAN "clip.vision.image_mean" -#define KEY_IMAGE_STD "clip.vision.image_std" -#define KEY_PROJ_TYPE "clip.projector_type" -#define KEY_FEATURE_LAYER "clip.vision.feature_layer" - -#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" -#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" -#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" - - -// -// tensor name constants -// - -#define TN_TOKEN_EMBD "%s.token_embd.weight" -#define TN_POS_EMBD "%s.position_embd.weight" -#define TN_CLASS_EMBD "v.class_embd" -#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat -#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" -#define TN_PATCH_BIAS "v.patch_embd.bias" -#define TN_ATTN_K "%s.blk.%d.attn_k.%s" -#define TN_ATTN_Q "%s.blk.%d.attn_q.%s" -#define TN_ATTN_V "%s.blk.%d.attn_v.%s" -#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" -#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" -#define TN_FFN_UP "%s.blk.%d.ffn_up.%s" -#define TN_LN_1 "%s.blk.%d.ln1.%s" -#define TN_LN_2 "%s.blk.%d.ln2.%s" -#define TN_LN_PRE "%s.pre_ln.%s" -#define TN_LN_POST "%s.post_ln.%s" -#define TN_TEXT_PROJ "text_projection.weight" -#define TN_VIS_PROJ "visual_projection.weight" -#define TN_LLAVA_PROJ "mm.%d.%s" -#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s" -#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" -#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" -#define TN_IMAGE_NEWLINE "model.image_newline" - -#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k" -#define TN_MINICPMV_QUERY "resampler.query" -#define TN_MINICPMV_PROJ "resampler.proj.weight" -#define TN_MINICPMV_KV_PROJ "resampler.kv.weight" -#define TN_MINICPMV_ATTN "resampler.attn.%s.%s" -#define TN_MINICPMV_LN "resampler.ln_%s.%s" - -#define TN_GLM_ADAPER_CONV "adapter.conv.%s" -#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s" -#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s" -#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s" -#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" 
-#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" -#define TN_GLM_BOI_W "adapter.boi" -#define TN_GLM_EOI_W "adapter.eoi" - - -enum projector_type { - PROJECTOR_TYPE_MLP, - PROJECTOR_TYPE_MLP_NORM, - PROJECTOR_TYPE_LDP, - PROJECTOR_TYPE_LDPV2, - PROJECTOR_TYPE_RESAMPLER, - PROJECTOR_TYPE_GLM_EDGE, - PROJECTOR_TYPE_MERGER, - PROJECTOR_TYPE_UNKNOWN, -}; - -static std::map PROJECTOR_TYPE_NAMES = { - { PROJECTOR_TYPE_MLP, "mlp" }, - { PROJECTOR_TYPE_LDP, "ldp" }, - { PROJECTOR_TYPE_LDPV2, "ldpv2"}, - { PROJECTOR_TYPE_RESAMPLER, "resampler"}, - { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, - { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, -}; - - -// -// utilities to get data from a gguf file -// - -static int get_key_idx(const gguf_context * ctx, const char * key) { - int i = gguf_find_key(ctx, key); - if (i == -1) { - LOG_ERR("key %s not found in file\n", key); - throw std::runtime_error(format("Missing required key: %s", key)); - } - - return i; -} - -static uint32_t get_u32(const gguf_context * ctx, const std::string & key) { - const int i = get_key_idx(ctx, key.c_str()); - - return gguf_get_val_u32(ctx, i); -} - -static float get_f32(const gguf_context * ctx, const std::string & key) { - const int i = get_key_idx(ctx, key.c_str()); - - return gguf_get_val_f32(ctx, i); -} - -static struct ggml_tensor * get_tensor(struct ggml_context * ctx, const std::string & name) { - struct ggml_tensor * cur = ggml_get_tensor(ctx, name.c_str()); - if (!cur) { - throw std::runtime_error(format("%s: unable to find tensor %s\n", __func__, name.c_str())); - } - - return cur; -} - -static std::string get_ftype(int ftype) { - return ggml_type_name(static_cast(ftype)); -} - -static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) { - switch (type) { - case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]); - case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]); - case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]); - case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]); - case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]); - case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]); - case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]); - case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); - case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); - case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); - case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? 
"true" : "false"; - default: return format("unknown type %d", type); - } -} - -static void replace_all(std::string & s, const std::string & search, const std::string & replace) { - if (search.empty()) { - return; - } - std::string builder; - builder.reserve(s.length()); - size_t pos = 0; - size_t last_pos = 0; - while ((pos = s.find(search, last_pos)) != std::string::npos) { - builder.append(s, last_pos, pos - last_pos); - builder.append(replace); - last_pos = pos + search.length(); - } - builder.append(s, last_pos, std::string::npos); - s = std::move(builder); -} - -static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { - const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); - - switch (type) { - case GGUF_TYPE_STRING: - return gguf_get_val_str(ctx_gguf, i); - case GGUF_TYPE_ARRAY: - { - const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); - int arr_n = gguf_get_arr_n(ctx_gguf, i); - const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); - std::stringstream ss; - ss << "["; - for (int j = 0; j < arr_n; j++) { - if (arr_type == GGUF_TYPE_STRING) { - std::string val = gguf_get_arr_str(ctx_gguf, i, j); - // escape quotes - replace_all(val, "\\", "\\\\"); - replace_all(val, "\"", "\\\""); - ss << '"' << val << '"'; - } else if (arr_type == GGUF_TYPE_ARRAY) { - ss << "???"; - } else { - ss << gguf_data_to_str(arr_type, data, j); - } - if (j < arr_n - 1) { - ss << ", "; - } - } - ss << "]"; - return ss.str(); - } - default: - return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0); - } -} - -static void print_tensor_info(const ggml_tensor * tensor, const char * prefix = "") { - size_t tensor_size = ggml_nbytes(tensor); - LOG_INF("%s: n_dims = %d, name = %s, tensor_size=%zu, shape:[%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "], type = %s\n", - prefix, ggml_n_dims(tensor), tensor->name, tensor_size, - tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], ggml_type_name(tensor->type)); -} - -static projector_type clip_projector_type_from_string(const std::string & name) { - for (const auto & kv : PROJECTOR_TYPE_NAMES) { // NOLINT - if (kv.second == name) { - return kv.first; - } - } - return PROJECTOR_TYPE_UNKNOWN; -} - #ifdef CLIP_DEBUG_FUNCTIONS static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) { std::ofstream file(filename, std::ios::binary); @@ -446,6 +158,11 @@ static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u // clip layers // +enum patch_merge_type { + PATCH_MERGE_FLAT, + PATCH_MERGE_SPATIAL_UNPAD, +}; + struct clip_hparams { int32_t image_size; int32_t patch_size; @@ -455,9 +172,9 @@ struct clip_hparams { int32_t n_head; int32_t n_layer; - float eps; + patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT; - char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default) + float eps; std::vector image_grid_pinpoints; int32_t image_crop_resolution; @@ -466,44 +183,44 @@ struct clip_hparams { struct clip_layer { // attention - struct ggml_tensor * k_w; - struct ggml_tensor * k_b; - struct ggml_tensor * q_w; - struct ggml_tensor * q_b; - struct ggml_tensor * v_w; - struct ggml_tensor * v_b; + struct ggml_tensor * k_w = nullptr; + struct ggml_tensor * k_b = nullptr; + struct ggml_tensor * q_w = nullptr; + struct ggml_tensor * q_b = nullptr; + struct ggml_tensor * v_w = nullptr; + struct ggml_tensor * v_b = nullptr; - struct ggml_tensor * o_w; - struct ggml_tensor * o_b; + struct ggml_tensor * o_w = 
nullptr; + struct ggml_tensor * o_b = nullptr; // layernorm 1 - struct ggml_tensor * ln_1_w; - struct ggml_tensor * ln_1_b; + struct ggml_tensor * ln_1_w = nullptr; + struct ggml_tensor * ln_1_b = nullptr; // ff - struct ggml_tensor * ff_i_w; - struct ggml_tensor * ff_i_b; + struct ggml_tensor * ff_i_w = nullptr; + struct ggml_tensor * ff_i_b = nullptr; - struct ggml_tensor * ff_o_w; - struct ggml_tensor * ff_o_b; + struct ggml_tensor * ff_o_w = nullptr; + struct ggml_tensor * ff_o_b = nullptr; // layernorm 2 - struct ggml_tensor * ln_2_w; - struct ggml_tensor * ln_2_b; + struct ggml_tensor * ln_2_w = nullptr; + struct ggml_tensor * ln_2_b = nullptr; }; struct clip_vision_model { struct clip_hparams hparams; // embeddings - struct ggml_tensor * class_embedding; - struct ggml_tensor * patch_embeddings_0; - struct ggml_tensor * patch_embeddings_1; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) - struct ggml_tensor * patch_bias; - struct ggml_tensor * position_embeddings; + struct ggml_tensor * class_embedding = nullptr; + struct ggml_tensor * patch_embeddings_0 = nullptr; + struct ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) + struct ggml_tensor * patch_bias = nullptr; + struct ggml_tensor * position_embeddings = nullptr; - struct ggml_tensor * pre_ln_w; - struct ggml_tensor * pre_ln_b; + struct ggml_tensor * pre_ln_w = nullptr; + struct ggml_tensor * pre_ln_b = nullptr; std::vector layers; @@ -513,80 +230,84 @@ struct clip_vision_model { struct ggml_tensor * projection; // LLaVA projection - struct ggml_tensor * mm_0_w = NULL; - struct ggml_tensor * mm_0_b = NULL; - struct ggml_tensor * mm_2_w = NULL; - struct ggml_tensor * mm_2_b = NULL; + struct ggml_tensor * mm_0_w = nullptr; + struct ggml_tensor * mm_0_b = nullptr; + struct ggml_tensor * mm_2_w = nullptr; + struct ggml_tensor * mm_2_b = nullptr; - struct ggml_tensor * image_newline = NULL; + struct ggml_tensor * image_newline = nullptr; // Yi type models with mlp+normalization projection - struct ggml_tensor * mm_1_w = NULL; // Yi type models have 0, 1, 3, 4 - struct ggml_tensor * mm_1_b = NULL; - struct ggml_tensor * mm_3_w = NULL; - struct ggml_tensor * mm_3_b = NULL; - struct ggml_tensor * mm_4_w = NULL; - struct ggml_tensor * mm_4_b = NULL; + struct ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4 + struct ggml_tensor * mm_1_b = nullptr; + struct ggml_tensor * mm_3_w = nullptr; + struct ggml_tensor * mm_3_b = nullptr; + struct ggml_tensor * mm_4_w = nullptr; + struct ggml_tensor * mm_4_b = nullptr; //GLMV-Edge projection - struct ggml_tensor * mm_model_adapter_conv_w; - struct ggml_tensor * mm_model_adapter_conv_b; - struct ggml_tensor * boi_w; - struct ggml_tensor * eoi_w; + struct ggml_tensor * mm_model_adapter_conv_w = nullptr; + struct ggml_tensor * mm_model_adapter_conv_b = nullptr; + struct ggml_tensor * boi_w = nullptr; + struct ggml_tensor * eoi_w = nullptr; // MobileVLM projection - struct ggml_tensor * mm_model_mlp_1_w; - struct ggml_tensor * mm_model_mlp_1_b; - struct ggml_tensor * mm_model_mlp_3_w; - struct ggml_tensor * mm_model_mlp_3_b; - struct ggml_tensor * mm_model_block_1_block_0_0_w; - struct ggml_tensor * mm_model_block_1_block_0_1_w; - struct ggml_tensor * mm_model_block_1_block_0_1_b; - struct ggml_tensor * mm_model_block_1_block_1_fc1_w; - struct ggml_tensor * mm_model_block_1_block_1_fc1_b; - struct ggml_tensor * mm_model_block_1_block_1_fc2_w; - struct ggml_tensor * 
mm_model_block_1_block_1_fc2_b; - struct ggml_tensor * mm_model_block_1_block_2_0_w; - struct ggml_tensor * mm_model_block_1_block_2_1_w; - struct ggml_tensor * mm_model_block_1_block_2_1_b; - struct ggml_tensor * mm_model_block_2_block_0_0_w; - struct ggml_tensor * mm_model_block_2_block_0_1_w; - struct ggml_tensor * mm_model_block_2_block_0_1_b; - struct ggml_tensor * mm_model_block_2_block_1_fc1_w; - struct ggml_tensor * mm_model_block_2_block_1_fc1_b; - struct ggml_tensor * mm_model_block_2_block_1_fc2_w; - struct ggml_tensor * mm_model_block_2_block_1_fc2_b; - struct ggml_tensor * mm_model_block_2_block_2_0_w; - struct ggml_tensor * mm_model_block_2_block_2_1_w; - struct ggml_tensor * mm_model_block_2_block_2_1_b; + struct ggml_tensor * mm_model_mlp_1_w = nullptr; + struct ggml_tensor * mm_model_mlp_1_b = nullptr; + struct ggml_tensor * mm_model_mlp_3_w = nullptr; + struct ggml_tensor * mm_model_mlp_3_b = nullptr; + struct ggml_tensor * mm_model_block_1_block_0_0_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_0_1_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_0_1_b = nullptr; + struct ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr; + struct ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr; + struct ggml_tensor * mm_model_block_1_block_2_0_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_2_1_w = nullptr; + struct ggml_tensor * mm_model_block_1_block_2_1_b = nullptr; + struct ggml_tensor * mm_model_block_2_block_0_0_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_0_1_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_0_1_b = nullptr; + struct ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr; + struct ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr; + struct ggml_tensor * mm_model_block_2_block_2_0_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_2_1_w = nullptr; + struct ggml_tensor * mm_model_block_2_block_2_1_b = nullptr; // MobileVLM_V2 projection - struct ggml_tensor * mm_model_mlp_0_w; - struct ggml_tensor * mm_model_mlp_0_b; - struct ggml_tensor * mm_model_mlp_2_w; - struct ggml_tensor * mm_model_mlp_2_b; - struct ggml_tensor * mm_model_peg_0_w; - struct ggml_tensor * mm_model_peg_0_b; + struct ggml_tensor * mm_model_mlp_0_w = nullptr; + struct ggml_tensor * mm_model_mlp_0_b = nullptr; + struct ggml_tensor * mm_model_mlp_2_w = nullptr; + struct ggml_tensor * mm_model_mlp_2_b = nullptr; + struct ggml_tensor * mm_model_peg_0_w = nullptr; + struct ggml_tensor * mm_model_peg_0_b = nullptr; // MINICPMV projection - struct ggml_tensor * mm_model_pos_embed_k; - struct ggml_tensor * mm_model_query; - struct ggml_tensor * mm_model_proj; - struct ggml_tensor * mm_model_kv_proj; - struct ggml_tensor * mm_model_attn_q_w; - struct ggml_tensor * mm_model_attn_q_b; - struct ggml_tensor * mm_model_attn_k_w; - struct ggml_tensor * mm_model_attn_k_b; - struct ggml_tensor * mm_model_attn_v_w; - struct ggml_tensor * mm_model_attn_v_b; - struct ggml_tensor * mm_model_attn_o_w; - struct ggml_tensor * mm_model_attn_o_b; - struct ggml_tensor * mm_model_ln_q_w; - struct ggml_tensor * mm_model_ln_q_b; - struct ggml_tensor * mm_model_ln_kv_w; - struct ggml_tensor * mm_model_ln_kv_b; - struct ggml_tensor * mm_model_ln_post_w; - struct ggml_tensor * 
mm_model_ln_post_b; + struct ggml_tensor * mm_model_pos_embed_k = nullptr; + struct ggml_tensor * mm_model_query = nullptr; + struct ggml_tensor * mm_model_proj = nullptr; + struct ggml_tensor * mm_model_kv_proj = nullptr; + struct ggml_tensor * mm_model_attn_q_w = nullptr; + struct ggml_tensor * mm_model_attn_q_b = nullptr; + struct ggml_tensor * mm_model_attn_k_w = nullptr; + struct ggml_tensor * mm_model_attn_k_b = nullptr; + struct ggml_tensor * mm_model_attn_v_w = nullptr; + struct ggml_tensor * mm_model_attn_v_b = nullptr; + struct ggml_tensor * mm_model_attn_o_w = nullptr; + struct ggml_tensor * mm_model_attn_o_b = nullptr; + struct ggml_tensor * mm_model_ln_q_w = nullptr; + struct ggml_tensor * mm_model_ln_q_b = nullptr; + struct ggml_tensor * mm_model_ln_kv_w = nullptr; + struct ggml_tensor * mm_model_ln_kv_b = nullptr; + struct ggml_tensor * mm_model_ln_post_w = nullptr; + struct ggml_tensor * mm_model_ln_post_b = nullptr; + + // gemma3 + struct ggml_tensor * mm_input_proj_w = nullptr; + struct ggml_tensor * mm_soft_emb_norm_w = nullptr; }; struct clip_ctx { @@ -601,33 +322,211 @@ struct clip_ctx { struct clip_vision_model vision_model; projector_type proj_type = PROJECTOR_TYPE_MLP; - int32_t max_feature_layer; + int32_t max_feature_layer; // unused in newer models like gemma3 float image_mean[3]; float image_std[3]; bool use_gelu = false; bool use_silu = false; - int32_t ftype = 1; - bool has_class_embedding = true; - bool has_pre_norm = true; - bool has_post_norm = false; - bool has_patch_bias = false; - - struct gguf_context * ctx_gguf; - struct ggml_context * ctx_data; + gguf_context_ptr ctx_gguf; + ggml_context_ptr ctx_data; std::vector buf_compute_meta; - // memory buffers to evaluate the model - ggml_backend_buffer_t params_buffer = NULL; + std::vector backend_ptrs; + std::vector backend_buft; - ggml_backend_t backend = NULL; - ggml_gallocr_t compute_alloc = NULL; + ggml_backend_t backend; + ggml_backend_t backend_cpu; + ggml_backend_buffer_ptr buf; - struct clip_image_size * load_image_size; + ggml_backend_sched_ptr sched; + + clip_image_size load_image_size; + + clip_ctx(clip_context_params & ctx_params) { + backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + backend = ctx_params.use_gpu + ? 
ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr) + : nullptr; + + if (backend) { + LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend)); + backend_ptrs.push_back(backend); + backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); + } else { + backend = backend_cpu; + LOG_INF("%s: CLIP using CPU backend\n", __func__); + } + + backend_ptrs.push_back(backend_cpu); + backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu)); + + sched.reset( + ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false) + ); + } + + ~clip_ctx() { + ggml_backend_free(backend); + if (backend != backend_cpu) { + ggml_backend_free(backend_cpu); + } + } }; -static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) { +static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) { + const auto & model = ctx->vision_model; + const auto & hparams = model.hparams; + + const int image_size = hparams.image_size; + int image_size_width = image_size; + int image_size_height = image_size; + + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int hidden_size = hparams.hidden_size; + const int n_head = hparams.n_head; + const int d_head = hidden_size / n_head; + const int n_layer = hparams.n_layer; + const float eps = hparams.eps; + + GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1 + + struct ggml_init_params params = { + /*.mem_size =*/ ctx->buf_compute_meta.size(), + /*.mem_buffer =*/ ctx->buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx0_ptr(ggml_init(params)); + auto ctx0 = ctx0_ptr.get(); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + // input raw + struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size); + inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + inp = ggml_add(ctx0, inp, model.patch_bias); + + // position embeddings + struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings); + + // loop over layers + for (int il = 0; il < n_layer; il++) { + struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states + + // layernorm1 + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b); + } + + // self-attention + { + + struct ggml_tensor * Q = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); + + Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + + struct ggml_tensor * K = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); + + K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + + struct ggml_tensor * V = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b); + + V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); 
+ + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches); + } + + // attention output + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b); + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, embeddings); + + embeddings = cur; // embeddings = residual, cur = hidden_states + + // layernorm2 + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + + // siglip uses gelu + cur = ggml_gelu(ctx0, cur); + + cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + + // residual 2 + cur = ggml_add(ctx0, embeddings, cur); + + embeddings = cur; + } + + // post-layernorm + if (model.post_ln_w) { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); + } + + if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { + const int batch_size = 1; + const int mm_tokens_per_image = 256; // default value for gemma3 + const int tokens_per_side = sqrt(mm_tokens_per_image); + const int patches_per_image = sqrt(num_patches); + const int kernel_size = patches_per_image / tokens_per_side; + + embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings)); + embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size); + + // doing a pool2d to reduce the number of output tokens to 256 + embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0); + embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size); + embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings)); + + // apply norm before projection + embeddings = ggml_rms_norm(ctx0, embeddings, eps); + embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w); + + // apply projection + embeddings = ggml_mul_mat(ctx0, + ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)), + embeddings); + } + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + return gf; +} + +static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { if (!ctx->has_vision_encoder) { LOG_ERR("This gguf file seems to have no vision encoder\n"); return nullptr; @@ -640,30 +539,27 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 int image_size_width = image_size; int image_size_height = image_size; if (ctx->has_minicpmv_projector) { - if (load_image_size == nullptr) { - load_image_size = clip_image_size_init(); - } - LOG_DBG("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height); - image_size_width = load_image_size->width; - image_size_height = load_image_size->height; + LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height); + image_size_width = load_image_size.width; + image_size_height = 
load_image_size.height; if (is_inf) { - image_size_width = imgs->data->nx; - image_size_height = imgs->data->ny; + image_size_width = imgs.entries[0]->nx; + image_size_height = imgs.entries[0]->ny; } } else if (ctx->has_qwen2vl_merger) { // use the image's native resolution when image is avaible if (is_inf) { // if (imgs->data->nx && imgs->data->ny) { - image_size_width = imgs->data->nx; - image_size_height = imgs->data->ny; + image_size_width = imgs.entries[0]->nx; + image_size_height = imgs.entries[0]->ny; } } const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int patches_w = image_size_width / patch_size; const int patches_h = image_size_height / patch_size; - const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); + const int num_positions = num_patches + (model.class_embedding ? 1 : 0); const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions; const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; @@ -671,7 +567,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 const float eps = hparams.eps; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; - const int batch_size = imgs->size; + const int batch_size = imgs.entries.size(); if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) { GGML_ASSERT(batch_size == 1); @@ -683,7 +579,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 /*.no_alloc =*/ true, }; - struct ggml_context * ctx0 = ggml_init(params); + ggml_context_ptr ctx0_ptr(ggml_init(params)); + auto ctx0 = ctx0_ptr.get(); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); @@ -715,7 +613,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); } - if (ctx->has_patch_bias) { + if (model.patch_bias) { // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); inp = ggml_add(ctx0, inp, model.patch_bias); } @@ -724,7 +622,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 if (ctx->has_llava_projector) { // concat class_embeddings and patch_embeddings - if (ctx->has_class_embedding) { + if (model.class_embedding) { embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); ggml_set_name(embeddings, "embeddings"); ggml_set_input(embeddings); @@ -761,7 +659,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } // pre-layernorm - if (ctx->has_pre_norm) { + if (model.pre_ln_w) { embeddings = ggml_norm(ctx0, embeddings, eps); ggml_set_name(embeddings, "pre_ln"); @@ -803,7 +701,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 ctx0, Q, positions, nullptr, d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); } - Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); @@ -827,7 +724,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = 
ggml_soft_max_inplace(ctx0, KQ); + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -871,7 +768,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } // post-layernorm - if (ctx->has_post_norm) { + if (model.post_ln_w) { embeddings = ggml_norm(ctx0, embeddings, eps); ggml_set_name(embeddings, "post_ln"); @@ -911,8 +808,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); embeddings = ggml_gelu(ctx0, embeddings); - embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + if (model.mm_2_w) { + embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + } } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); @@ -1115,7 +1014,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); - Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b); struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b); // permute @@ -1129,7 +1027,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_inplace(ctx0, KQ); + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -1171,9 +1069,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); } } else { - GGML_ABORT("fatel error"); + GGML_ABORT("fatal error"); } - } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + } + else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); @@ -1190,572 +1089,539 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // build the graph ggml_build_forward_expand(gf, embeddings); - ggml_free(ctx0); - return gf; } -// read and create ggml_context containing the tensors and their data -struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { - struct ggml_context * meta = NULL; - - struct gguf_init_params params = { - /*.no_alloc = */ true, - /*.ctx = */ &meta, - }; - - struct gguf_context * ctx = gguf_init_from_file(fname, params); - if (!ctx) { - throw std::runtime_error(format("%s: failed to load CLIP model from %s. 
Does this file exist?\n", __func__, fname)); +static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { + if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { + return clip_image_build_graph_siglip(ctx, imgs); + } else { + // TODO: we should have one build_* function per model + return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf); } +} - if (verbosity >= 1) { - const int n_tensors = gguf_get_n_tensors(ctx); - const int n_kv = gguf_get_n_kv(ctx); - const int ftype = get_u32(ctx, KEY_FTYPE); - const std::string ftype_str = get_ftype(ftype); - const int idx_desc = get_key_idx(ctx, KEY_DESCRIPTION); - const std::string description = gguf_get_val_str(ctx, idx_desc); - const int idx_name = gguf_find_key(ctx, KEY_NAME); - if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug - const std::string name = gguf_get_val_str(ctx, idx_name); - LOG_INF("%s: model name: %s\n", __func__, name.c_str()); - } - LOG_INF("%s: description: %s\n", __func__, description.c_str()); - LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx)); - LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); - LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors); - LOG_INF("%s: n_kv: %d\n", __func__, n_kv); - LOG_INF("%s: ftype: %s\n", __func__, ftype_str.c_str()); - LOG_INF("\n"); - } - const int n_tensors = gguf_get_n_tensors(ctx); +struct clip_model_loader { + ggml_context_ptr ctx_meta; + gguf_context_ptr ctx_gguf; - // kv - const int n_kv = gguf_get_n_kv(ctx); - LOG_INF("%s: loaded meta data with %d key-value pairs and %d tensors from %s\n", - __func__, n_kv, n_tensors, fname); - { - std::map n_type; + clip_ctx & ctx_clip; + std::string fname; - for (int i = 0; i < n_tensors; i++) { - enum ggml_type type = gguf_get_tensor_type(ctx, i); + size_t model_size; // in bytes - n_type[type]++; + // TODO @ngxson : we should not pass clip_ctx here, it should be clip_vision_model + clip_model_loader(const char * fname, clip_ctx & ctx_clip) : ctx_clip(ctx_clip), fname(fname) { + struct ggml_context * meta = nullptr; + + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &meta, + }; + + ctx_gguf = gguf_context_ptr(gguf_init_from_file(fname, params)); + if (!ctx_gguf.get()) { + throw std::runtime_error(string_format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname)); } - LOG_INF("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); - for (int i = 0; i < n_kv; i++) { - const char * name = gguf_get_key(ctx, i); - const enum gguf_type type = gguf_get_kv_type(ctx, i); - const std::string type_name = - type == GGUF_TYPE_ARRAY - ? 
format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(ctx, i)), gguf_get_arr_n(ctx, i)) - : gguf_type_name(type); + ctx_meta.reset(meta); - std::string value = gguf_kv_to_str(ctx, i); - const size_t MAX_VALUE_LEN = 40; - if (value.size() > MAX_VALUE_LEN) { - value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); - } - replace_all(value, "\n", "\\n"); + const int n_tensors = gguf_get_n_tensors(ctx_gguf.get()); - LOG_INF("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str()); + // print gguf info + { + std::string name; + get_string(KEY_NAME, name, false); + std::string description; + get_string(KEY_DESCRIPTION, description, false); + LOG_INF("%s: model name: %s\n", __func__, name.c_str()); + LOG_INF("%s: description: %s\n", __func__, description.c_str()); + LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx_gguf.get())); + LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx_gguf.get())); + LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors); + LOG_INF("%s: n_kv: %d\n", __func__, (int)gguf_get_n_kv(ctx_gguf.get())); + LOG_INF("\n"); } - // print type counts - for (auto & kv : n_type) { - if (kv.second == 0) { - continue; - } - - LOG_INF("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); - } - } - - // data - size_t model_size = 0; - { - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx, i); - const size_t offset = gguf_get_tensor_offset(ctx, i); - enum ggml_type type = gguf_get_tensor_type(ctx, i); - struct ggml_tensor * cur = ggml_get_tensor(meta, name); - size_t tensor_size = ggml_nbytes(cur); - model_size += tensor_size; - if (verbosity >= 3) { - LOG_INF("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n", - __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type)); + // tensors + { + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name(ctx_gguf.get(), i); + const size_t offset = gguf_get_tensor_offset(ctx_gguf.get(), i); + enum ggml_type type = gguf_get_tensor_type(ctx_gguf.get(), i); + struct ggml_tensor * cur = ggml_get_tensor(meta, name); + size_t tensor_size = ggml_nbytes(cur); + model_size += tensor_size; + LOG_DBG("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n", + __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type)); } } } - clip_ctx * new_clip = new clip_ctx{}; - - // update projector type - { - int idx = gguf_find_key(ctx, KEY_PROJ_TYPE); - if (idx != -1) { - const std::string proj_type = gguf_get_val_str(ctx, idx); - new_clip->proj_type = clip_projector_type_from_string(proj_type); - } else { - new_clip->proj_type = PROJECTOR_TYPE_MLP; - } - - if (new_clip->proj_type == PROJECTOR_TYPE_MLP) { - if (gguf_find_tensor(ctx, format(TN_LLAVA_PROJ, 3, "weight").c_str()) != -1) { - new_clip->proj_type = PROJECTOR_TYPE_MLP_NORM; + void load_hparams() { + // projector type + { + std::string proj_type; + get_string(KEY_PROJ_TYPE, proj_type, false); + if (!proj_type.empty()) { + ctx_clip.proj_type = clip_projector_type_from_string(proj_type); + } + if (ctx_clip.proj_type == PROJECTOR_TYPE_UNKNOWN) { + throw std::runtime_error(string_format("%s: unknown projector type: 
%s\n", __func__, proj_type.c_str())); } } - } - ggml_backend_t backend = ggml_backend_init_best(); - if (backend == nullptr) { - LOG_ERR("%s: failed to initialize backend\n", __func__); - clip_free(new_clip); - gguf_free(ctx); - return nullptr; - } - LOG_INF("%s: using %s backend\n", __func__, ggml_backend_name(backend)); - new_clip->backend = backend; + // other hparams + { + get_bool(KEY_HAS_TEXT_ENC, ctx_clip.has_text_encoder, false); + get_bool(KEY_HAS_VIS_ENC, ctx_clip.has_vision_encoder, false); + GGML_ASSERT(ctx_clip.has_vision_encoder); + GGML_ASSERT(!ctx_clip.has_text_encoder); - // model size and capabilities - { - int idx = get_key_idx(ctx, KEY_HAS_TEXT_ENC); - new_clip->has_text_encoder = gguf_get_val_bool(ctx, idx); + // legacy keys, use KEY_PROJ_TYPE instead + get_bool(KEY_HAS_LLAVA_PROJ, ctx_clip.has_llava_projector, false); + get_bool(KEY_HAS_MINICPMV_PROJ, ctx_clip.has_minicpmv_projector, false); + get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); + get_bool(KEY_HAS_GLM_PROJ, ctx_clip.has_glm_projector, false); + get_bool(KEY_HAS_QWEN2VL_MERGER, ctx_clip.has_qwen2vl_merger, false); + // !!! do NOT extend the list above, use KEY_PROJ_TYPE instead - idx = get_key_idx(ctx, KEY_HAS_VIS_ENC); - new_clip->has_vision_encoder = gguf_get_val_bool(ctx, idx); + get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false); + get_bool(KEY_USE_SILU, ctx_clip.use_silu, false); - idx = gguf_find_key(ctx, KEY_HAS_LLAVA_PROJ); - if (idx != -1) { - new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx); - } + auto & hparams = ctx_clip.vision_model.hparams; + get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size); + get_u32(string_format(KEY_N_HEAD, "vision"), hparams.n_head); + get_u32(string_format(KEY_N_FF, "vision"), hparams.n_intermediate); + get_u32(string_format(KEY_N_BLOCK, "vision"), hparams.n_layer); + get_u32(string_format(KEY_PROJ_DIM, "vision"), hparams.projection_dim); + get_f32(string_format(KEY_LAYER_NORM_EPS, "vision"), hparams.eps); + get_u32(KEY_IMAGE_SIZE, hparams.image_size); + get_u32(KEY_PATCH_SIZE, hparams.patch_size); + get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); + get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); - idx = gguf_find_key(ctx, KEY_HAS_MINICPMV_PROJ); - if (idx != -1) { - new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx); - } + { + std::string mm_patch_merge_type; + get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false); + if (mm_patch_merge_type == "spatial_unpad") { + hparams.mm_patch_merge_type = PATCH_MERGE_SPATIAL_UNPAD; + } + } - idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION); - if (idx != -1) { - new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); - } + { + int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN); + int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD); + GGML_ASSERT(idx_mean >= 0 && "image_mean not found"); + GGML_ASSERT(idx_std >= 0 && "image_std not found"); + const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean); + const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std); + for (int i = 0; i < 3; ++i) { + ctx_clip.image_mean[i] = mean_data[i]; + ctx_clip.image_std[i] = std_data[i]; + } + } - idx = gguf_find_key(ctx, KEY_HAS_GLM_PROJ); - if (idx != -1) { - new_clip->has_glm_projector = gguf_get_val_bool(ctx, idx); - } + // Load the vision feature layer indices if they are explicitly provided; + // if multiple vision feature layers are present, the values will be 
concatenated + // to form the final visual features. + // NOTE: gguf conversions should standardize the values of the vision feature layer to + // be non-negative, since we use -1 to mark values as unset here. + std::vector vision_feature_layer; + get_arr_int(KEY_FEATURE_LAYER, vision_feature_layer, false); + // convert std::vector to std::unordered_set + for (auto & layer : vision_feature_layer) { + hparams.vision_feature_layer.insert(layer); + } + // Calculate the deepest feature layer based on hparams and projector type + ctx_clip.max_feature_layer = get_deepest_feature_layer(&ctx_clip); - idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER); - if (idx != -1) { - new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx); - } - // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search - - GGML_ASSERT(new_clip->has_vision_encoder); - GGML_ASSERT(!new_clip->has_text_encoder); - - idx = get_key_idx(ctx, KEY_USE_GELU); - new_clip->use_gelu = gguf_get_val_bool(ctx, idx); - - try { - idx = get_key_idx(ctx, KEY_USE_SILU); - new_clip->use_silu = gguf_get_val_bool(ctx, idx); - } catch (std::runtime_error & /*e*/) { - new_clip->use_silu = false; - } - - if (verbosity >= 1) { - LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder); - LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); - LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); - LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); - LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector); - LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); - LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); + LOG_INF("%s: text_encoder: %d\n", __func__, ctx_clip.has_text_encoder); + LOG_INF("%s: vision_encoder: %d\n", __func__, ctx_clip.has_vision_encoder); + LOG_INF("%s: llava_projector: %d\n", __func__, ctx_clip.has_llava_projector); + LOG_INF("%s: minicpmv_projector: %d\n", __func__, ctx_clip.has_minicpmv_projector); + LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version); + LOG_INF("%s: glm_projector: %d\n", __func__, ctx_clip.has_glm_projector); + LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); + LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0); } } - LOG_INF("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors); + void load_tensors() { + std::map tensor_offset; + std::vector tensors_to_load; - // load tensors - { - std::vector read_buf; + // get offsets + for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) { + const char * name = gguf_get_tensor_name(ctx_gguf.get(), i); + tensor_offset[name] = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), i); + } + + // create data context struct ggml_init_params params = { - /*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(), + /*.mem_size =*/ (gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - - new_clip->ctx_data = ggml_init(params); - if (!new_clip->ctx_data) { - LOG_ERR("%s: ggml_init() failed\n", __func__); - clip_free(new_clip); - gguf_free(ctx); - return nullptr; - } -#ifdef _WIN32 - int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0); - if (!wlen) { - return NULL; - } - 
wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t)); - wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wbuf, wlen); - if (!wlen) { - free(wbuf); - return NULL; - } -#if __GLIBCXX__ - int fd = _wopen(wbuf, _O_RDONLY | _O_BINARY); - __gnu_cxx::stdio_filebuf buffer(fd, std::ios_base::in); - std::istream fin(&buffer); -#else // MSVC - // unused in our current build - auto fin = std::ifstream(wbuf, std::ios::binary); -#endif - free(wbuf); -#else - auto fin = std::ifstream(fname, std::ios::binary); -#endif - if (!fin) { - LOG_ERR("cannot open model file for loading tensors\n"); - clip_free(new_clip); - gguf_free(ctx); - return nullptr; + ctx_clip.ctx_data.reset(ggml_init(params)); + if (!ctx_clip.ctx_data) { + throw std::runtime_error(string_format("%s: failed to init ggml context\n", __func__)); } - // add tensors to context - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx, i); - struct ggml_tensor * t = ggml_get_tensor(meta, name); - struct ggml_tensor * cur = ggml_dup_tensor(new_clip->ctx_data, t); - ggml_set_name(cur, name); - } - - // alloc memory and offload data - new_clip->params_buffer = ggml_backend_alloc_ctx_tensors(new_clip->ctx_data, new_clip->backend); - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx, i); - struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name); - const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i); - fin.seekg(offset, std::ios::beg); - if (!fin) { - LOG_ERR("%s: failed to seek for tensor %s\n", __func__, name); - clip_free(new_clip); - gguf_free(ctx); - return nullptr; + // helper function + auto get_tensor = [&](const std::string & name, bool required = true) { + struct ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str()); + if (!cur && required) { + throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str())); } - int num_bytes = ggml_nbytes(cur); - if (ggml_backend_buffer_is_host(new_clip->params_buffer)) { - // for the CPU and Metal backend, we can read directly into the tensor - fin.read(reinterpret_cast(cur->data), num_bytes); - } else { - // read into a temporary buffer first, then copy to device memory - read_buf.resize(num_bytes); - fin.read(reinterpret_cast(read_buf.data()), num_bytes); - ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); + if (cur) { + tensors_to_load.push_back(cur); + // add tensors to context + struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur); + ggml_set_name(data_tensor, cur->name); + cur = data_tensor; } - } -#if defined(_WIN32) && defined(__GLIBCXX__) - close(fd); -#else - fin.close(); -#endif - } + return cur; + }; - // vision model - if (new_clip->has_vision_encoder) { - // load vision model - auto & vision_model = new_clip->vision_model; - auto & hparams = vision_model.hparams; - hparams.hidden_size = get_u32(ctx, format(KEY_N_EMBD, "vision")); - hparams.n_head = get_u32(ctx, format(KEY_N_HEAD, "vision")); - hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision")); - hparams.n_layer = get_u32(ctx, format(KEY_N_BLOCK, "vision")); - hparams.image_size = get_u32(ctx, KEY_IMAGE_SIZE); - hparams.patch_size = get_u32(ctx, KEY_PATCH_SIZE); - hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision")); - hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision")); + auto & vision_model = ctx_clip.vision_model; - try { - int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS); - int n = 
gguf_get_arr_n(ctx, idx); - const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx); - for (int i = 0; i < n; ++i) { - hparams.image_grid_pinpoints.push_back(pinpoints[i]); - } - } catch (std::runtime_error & /*e*/) { } + vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false); - // Load the vision feature layer indices if they are explicitly provided; - // if multiple vision feature layers are present, the values will be concatenated - // to form the final visual features. - // NOTE: gguf conversions should standardize the values of the vision feature layer to - // be non-negative, since we use -1 to mark values as unset here. - try { - int idx = get_key_idx(ctx, KEY_FEATURE_LAYER); - int n = gguf_get_arr_n(ctx, idx); + vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false); + vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false); - const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); + vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false); + vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false); - for (int i = 0; i < n; ++i) { - hparams.vision_feature_layer.insert(vision_feature_layer[i]); - } - } catch (std::runtime_error & /*e*/) { } - - try { - int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); - strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx)); - } catch (std::runtime_error & /*e*/) { - strcpy(hparams.mm_patch_merge_type, "flat"); + vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false); + vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); + vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); + if (vision_model.patch_embeddings_1 == nullptr) { + ctx_clip.has_qwen2vl_merger = false; } - try { - hparams.image_crop_resolution = get_u32(ctx, KEY_IMAGE_CROP_RESOLUTION); // llava-1.6 - } catch(const std::exception& /*e*/) { - hparams.image_crop_resolution = hparams.image_size; - } + vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false); - int idx_mean = get_key_idx(ctx, KEY_IMAGE_MEAN); - int idx_std = get_key_idx(ctx, KEY_IMAGE_STD); - - const float * mean_data = (const float *)gguf_get_arr_data(ctx, idx_mean); - const float * std_data = (const float *)gguf_get_arr_data(ctx, idx_std); - - for (int i = 0; i < 3; ++i) { - new_clip->image_mean[i] = mean_data[i]; - new_clip->image_std[i] = std_data[i]; - } - - // Calculate the deepest feature layer based on hparams and projector type - new_clip->max_feature_layer = get_deepest_feature_layer(new_clip); - - if (verbosity >= 2) { - LOG_INF("\n%s: vision model hparams\n", __func__); - LOG_INF("image_size %d\n", hparams.image_size); - LOG_INF("patch_size %d\n", hparams.patch_size); - LOG_INF("v_hidden_size %d\n", hparams.hidden_size); - LOG_INF("v_n_intermediate %d\n", hparams.n_intermediate); - LOG_INF("v_projection_dim %d\n", hparams.projection_dim); - LOG_INF("v_n_head %d\n", hparams.n_head); - LOG_INF("v_n_layer %d\n", hparams.n_layer); - LOG_INF("v_eps %f\n", hparams.eps); - LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]); - LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]); - LOG_INF("v_image_grid_pinpoints: "); - for (const auto & pp : hparams.image_grid_pinpoints) { - LOG_INF("%d ", pp); - } - LOG_INF("\n"); - LOG_INF("v_vision_feature_layer: "); - for (const auto & 
feature_layer: hparams.vision_feature_layer) { - LOG_INF("%d ", feature_layer); - } - LOG_INF("\n"); - LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type); - - } - - try { - vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD); - new_clip->has_class_embedding = true; - } catch (const std::exception& /*e*/) { - new_clip->has_class_embedding = false; - } - - try { - vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight")); - vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias")); - new_clip->has_pre_norm = true; - } catch (std::exception & /*e*/) { - new_clip->has_pre_norm = false; - } - - try { - vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight")); - vision_model.post_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias")); - new_clip->has_post_norm = true; - } catch (std::exception & /*e*/) { - new_clip->has_post_norm = false; - } - - try { - vision_model.patch_bias = get_tensor(new_clip->ctx_data, TN_PATCH_BIAS); - new_clip->has_patch_bias = true; - } catch (std::exception & /*e*/) { - new_clip->has_patch_bias = false; - } - - try { - vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD); - vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v")); - } catch(const std::exception& /*e*/) { - LOG_ERR("%s: failed to load vision model tensors\n", __func__); - } - try { - vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1); - } catch(const std::exception& /*e*/) { - new_clip->has_qwen2vl_merger = false; - } - - // LLaVA projection - if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) { - vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); - vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); - try { - // Yi-type llava - vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "weight")); - vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "bias")); - } catch (std::runtime_error & /*e*/) { } - try { - // missing in Yi-type llava - vision_model.mm_2_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight")); - vision_model.mm_2_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias")); - } catch (std::runtime_error & /*e*/) { } - try { - // Yi-type llava - vision_model.mm_3_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "weight")); - vision_model.mm_3_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "bias")); - } catch (std::runtime_error & /*e*/) { } - try { - // Yi-type llava - vision_model.mm_4_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "weight")); - vision_model.mm_4_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "bias")); - } catch (std::runtime_error & /*e*/) { } - try { - vision_model.image_newline = get_tensor(new_clip->ctx_data, TN_IMAGE_NEWLINE); - // LOG_INF("%s: image_newline tensor (llava-1.6) found\n", __func__); - } catch (std::runtime_error & /*e*/) { } - } else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) { - // MobileVLM projection - vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "weight")); - vision_model.mm_model_mlp_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "bias")); - vision_model.mm_model_mlp_3_w = 
get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "weight")); - vision_model.mm_model_mlp_3_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "bias")); - vision_model.mm_model_block_1_block_0_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight")); - vision_model.mm_model_block_1_block_0_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight")); - vision_model.mm_model_block_1_block_0_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias")); - vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight")); - vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias")); - vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight")); - vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias")); - vision_model.mm_model_block_1_block_2_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight")); - vision_model.mm_model_block_1_block_2_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight")); - vision_model.mm_model_block_1_block_2_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias")); - vision_model.mm_model_block_2_block_0_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight")); - vision_model.mm_model_block_2_block_0_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight")); - vision_model.mm_model_block_2_block_0_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias")); - vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight")); - vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias")); - vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight")); - vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias")); - vision_model.mm_model_block_2_block_2_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight")); - vision_model.mm_model_block_2_block_2_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight")); - vision_model.mm_model_block_2_block_2_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias")); - } - else if (new_clip->proj_type == PROJECTOR_TYPE_LDPV2) - { - // MobilVLM_V2 projection - vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "weight")); - vision_model.mm_model_mlp_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "bias")); - vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "weight")); - vision_model.mm_model_mlp_2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "bias")); - vision_model.mm_model_peg_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "weight")); - vision_model.mm_model_peg_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "bias")); - } - else if (new_clip->proj_type == PROJECTOR_TYPE_RESAMPLER) { - // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); - 
vision_model.mm_model_pos_embed_k = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD_K); - vision_model.mm_model_query = get_tensor(new_clip->ctx_data, TN_MINICPMV_QUERY); - vision_model.mm_model_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_PROJ); - vision_model.mm_model_kv_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_KV_PROJ); - vision_model.mm_model_attn_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "weight")); - vision_model.mm_model_attn_k_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "weight")); - vision_model.mm_model_attn_v_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "weight")); - vision_model.mm_model_attn_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "bias")); - vision_model.mm_model_attn_k_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "bias")); - vision_model.mm_model_attn_v_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "bias")); - vision_model.mm_model_attn_o_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "weight")); - vision_model.mm_model_attn_o_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "bias")); - vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "weight")); - vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "bias")); - vision_model.mm_model_ln_kv_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "weight")); - vision_model.mm_model_ln_kv_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "bias")); - vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); - vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); - } - else if (new_clip->proj_type == PROJECTOR_TYPE_GLM_EDGE) { - vision_model.mm_model_adapter_conv_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "weight")); - vision_model.mm_model_adapter_conv_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "bias")); - vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_LINEAR,"weight")); - vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"weight")); - vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"bias")); - vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_H_2_4H,"weight")); - vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_GATE,"weight")); - vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); - vision_model.boi_w = get_tensor(new_clip->ctx_data, TN_GLM_BOI_W); - vision_model.eoi_w = get_tensor(new_clip->ctx_data, TN_GLM_EOI_W); - } - else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) { - vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); - vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); - vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight")); - vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias")); - } - else { - std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type]; - throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); - } - - 
vision_model.layers.resize(hparams.n_layer); - - for (int il = 0; il < hparams.n_layer; ++il) { + // layers + vision_model.layers.resize(vision_model.hparams.n_layer); + for (int il = 0; il < vision_model.hparams.n_layer; ++il) { auto & layer = vision_model.layers[il]; - layer.k_w = get_tensor(new_clip->ctx_data, format(TN_ATTN_K, "v", il, "weight")); - layer.q_w = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q, "v", il, "weight")); - layer.v_w = get_tensor(new_clip->ctx_data, format(TN_ATTN_V, "v", il, "weight")); - layer.o_w = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "weight")); - layer.ln_1_w = get_tensor(new_clip->ctx_data, format(TN_LN_1, "v", il, "weight")); - layer.ln_2_w = get_tensor(new_clip->ctx_data, format(TN_LN_2, "v", il, "weight")); - layer.ff_i_w = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN, "v", il, "weight")); - layer.ff_o_w = get_tensor(new_clip->ctx_data, format(TN_FFN_UP, "v", il, "weight")); - layer.k_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_K, "v", il, "bias")); - layer.q_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q, "v", il, "bias")); - layer.v_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_V, "v", il, "bias")); - layer.o_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "bias")); - layer.ln_1_b = get_tensor(new_clip->ctx_data, format(TN_LN_1, "v", il, "bias")); - layer.ln_2_b = get_tensor(new_clip->ctx_data, format(TN_LN_2, "v", il, "bias")); - layer.ff_i_b = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN, "v", il, "bias")); - layer.ff_o_b = get_tensor(new_clip->ctx_data, format(TN_FFN_UP, "v", il, "bias")); + layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight")); + layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight")); + layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight")); + layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight")); + layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false); + layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false); + layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight")); + layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight")); + layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false); + layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false); + layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false); + layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false); + layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false); + layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false); + layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false); + layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false); + } + + switch (ctx_clip.proj_type) { + case PROJECTOR_TYPE_MLP: + case PROJECTOR_TYPE_MLP_NORM: + { + // LLaVA projection + vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false); + vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); + // Yi-type llava + vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false); + vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); + // missing in Yi-type llava + vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false); + vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + // Yi-type llava + 
vision_model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false); + vision_model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false); + vision_model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false); + vision_model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false); + if (vision_model.mm_3_w) { + // TODO: this is a hack to support Yi-type llava + ctx_clip.proj_type = PROJECTOR_TYPE_MLP_NORM; + } + vision_model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false); + } break; + case PROJECTOR_TYPE_LDP: + { + // MobileVLM projection + vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + vision_model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); + vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + vision_model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); + vision_model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight")); + vision_model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight")); + vision_model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias")); + vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight")); + vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias")); + vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight")); + vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias")); + vision_model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight")); + vision_model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight")); + vision_model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias")); + vision_model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight")); + vision_model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight")); + vision_model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias")); + vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight")); + vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias")); + vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight")); + vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias")); + vision_model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight")); + vision_model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight")); + vision_model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias")); + } break; + case PROJECTOR_TYPE_LDPV2: + { + // MobilVLM_V2 projection + vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + vision_model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); + vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); + 
vision_model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias")); + vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight")); + vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias")); + } break; + case PROJECTOR_TYPE_RESAMPLER: + { + // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); + vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K); + vision_model.mm_model_query = get_tensor(TN_MINICPMV_QUERY); + vision_model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ); + vision_model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ); + vision_model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight")); + vision_model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight")); + vision_model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight")); + vision_model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias")); + vision_model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias")); + vision_model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias")); + vision_model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight")); + vision_model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias")); + vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight")); + vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias")); + vision_model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight")); + vision_model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias")); + vision_model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight")); + vision_model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias")); + } break; + case PROJECTOR_TYPE_GLM_EDGE: + { + vision_model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight")); + vision_model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias")); + vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR,"weight")); + vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1,"weight")); + vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1,"bias")); + vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H,"weight")); + vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE,"weight")); + vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); + vision_model.boi_w = get_tensor(TN_GLM_BOI_W); + vision_model.eoi_w = get_tensor(TN_GLM_EOI_W); + } break; + case PROJECTOR_TYPE_MERGER: + { + vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; + case PROJECTOR_TYPE_GEMMA3: + { + vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); + vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); + } break; + default: + GGML_ASSERT(false && "unknown projector type"); + } + + // load data + { + std::vector read_buf; + +#ifdef _WIN32 + 
int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0);
+        if (!wlen) {
+            throw std::runtime_error(string_format("%s: failed to convert filename to wide string\n", __func__));
+        }
+        wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
+        wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wbuf, wlen);
+        if (!wlen) {
+            free(wbuf);
+            throw std::runtime_error(string_format("%s: failed to convert filename to wide string\n", __func__));
+        }
+#if __GLIBCXX__
+        int fd = _wopen(wbuf, _O_RDONLY | _O_BINARY);
+        __gnu_cxx::stdio_filebuf<char> buffer(fd, std::ios_base::in);
+        std::istream fin(&buffer);
+#else // MSVC
+        // unused in our current build
+        auto fin = std::ifstream(wbuf, std::ios::binary);
+#endif
+        free(wbuf);
+#else
+        auto fin = std::ifstream(fname, std::ios::binary);
+#endif
+        if (!fin) {
+            throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
+        }
+
+        // alloc memory and offload data
+        ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
+        ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
+        ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+        for (auto & t : tensors_to_load) {
+            struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
+            const size_t offset = tensor_offset[t->name];
+            fin.seekg(offset, std::ios::beg);
+            if (!fin) {
+                throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
+            }
+            size_t num_bytes = ggml_nbytes(cur);
+            if (ggml_backend_buft_is_host(buft)) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(num_bytes);
+                fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
+                ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+            }
+        }
+#if defined(_WIN32) && defined(__GLIBCXX__)
+        close(fd);
+#else
+        fin.close();
+#endif
+
+        LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
     }
 }
-    ggml_free(meta);
+    void alloc_compute_meta() {
+        ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
-    new_clip->ctx_gguf = ctx;
-
-    // measure mem requirement and allocate
-    {
-        new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
-        new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
+        // create a fake batch
         clip_image_f32_batch batch;
-        batch.size = 1;
-        batch.data = nullptr;
-        ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false);
-        ggml_gallocr_reserve(new_clip->compute_alloc, gf);
-        size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
-        LOG_INF("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
+        clip_image_f32_ptr img(clip_image_f32_init());
+        clip_image_size image_size;
+        image_size.width = clip_get_image_size(&ctx_clip);
+        image_size.height = clip_get_image_size(&ctx_clip);
+        int n_patches = clip_get_image_size(&ctx_clip) / image_size.width;
+        img->nx = n_patches;
+        img->ny = n_patches;
+        img->buf.resize(n_patches * image_size.width * image_size.height * 3);
+        batch.entries.push_back(std::move(img));
+
+        ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
+        ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
+        for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) {
+            ggml_backend_t backend = ctx_clip.backend_ptrs[i];
+            ggml_backend_buffer_type_t buft = ctx_clip.backend_buft[i];
+            size_t size = ggml_backend_sched_get_buffer_size(ctx_clip.sched.get(), backend);
+            if (size > 1) {
+                LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
+                        ggml_backend_buft_name(buft),
+                        size / 1024.0 / 1024.0);
+            }
+        }
     }
-    return new_clip;
+    void get_bool(const std::string & key, bool & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_bool(ctx_gguf.get(), i);
+    }
+
+    void get_i32(const std::string & key, int & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_i32(ctx_gguf.get(), i);
+    }
+
+    void get_u32(const std::string & key, int & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_u32(ctx_gguf.get(), i);
+    }
+
+    void get_f32(const std::string & key, float & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_f32(ctx_gguf.get(), i);
+    }
+
+    void get_string(const std::string & key, std::string & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
+    }
+
+    void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        int n = gguf_get_arr_n(ctx_gguf.get(), i);
+        output.resize(n);
+        const int32_t * values = (const int32_t *)gguf_get_arr_data(ctx_gguf.get(), i);
+        for (int i = 0; i < n; ++i) {
+            output[i] = values[i];
+        }
+    }
+};
+
+// read and create ggml_context containing the tensors and their data
+struct clip_ctx * clip_model_load(const char * fname, const int verbosity) {
+    return clip_init(fname, clip_context_params{
+        /* use_gpu */   true,
+        /* verbosity */ static_cast<ggml_log_level>(verbosity),
+    });
+}
+
+struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) {
+    g_logger_state.verbosity_thold = ctx_params.verbosity;
+    clip_ctx * ctx_clip = new clip_ctx(ctx_params);
+
+    try {
+        clip_model_loader loader(fname, *ctx_clip);
+        loader.load_hparams();
+        loader.load_tensors();
+        loader.alloc_compute_meta();
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
+        delete ctx_clip;
+        return nullptr;
+    }
+
+    return ctx_clip;
 }
 void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
-    ctx_clip->load_image_size = load_image_size;
+    ctx_clip->load_image_size = *load_image_size; // copy
 }
 struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
-    return ctx_clip->load_image_size;
+    return &ctx_clip->load_image_size;
 }
 struct clip_image_size * clip_image_size_init() {
@@ -1773,22 +1639,56 @@ struct clip_image_f32 * clip_image_f32_init() {
     return new clip_image_f32();
 }
-void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
-void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
-void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) {
-    if (batch->size > 0) {
-        delete[] batch->data;
-        batch->size = 0;
-    }
-}
-void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) {
-    if (batch->size > 0) {
-        delete[] batch->data;
-        batch->size = 0;
-    }
+struct clip_image_f32_batch * clip_image_f32_batch_init() {
+    return new clip_image_f32_batch();
 }
-void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img) {
+unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
+    if (nx) *nx = img->nx;
+    if (ny) *ny = img->ny;
+    return img->buf.data();
+}
+
+void clip_image_size_free(struct clip_image_size * load_image_size) {
+    if (load_image_size == nullptr) {
+        return;
+    }
+    delete load_image_size;
+}
+void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
+
+size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
+    return batch->entries.size();
+}
+
+size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return 0;
+    }
+    return batch->entries[idx]->nx;
+}
+
+size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return 0;
+    }
+    return batch->entries[idx]->ny;
+}
+
+clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return nullptr;
+    }
+    return batch->entries[idx].get();
+}
+
+void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
     img->nx = nx;
     img->ny = ny;
     img->buf.resize(3 * nx * ny);
@@ -1859,14 +1759,15 @@ static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int ta
 }
 // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
-static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) {
-    dst->nx = src->nx;
-    dst->ny = src->ny;
-    dst->buf.resize(src->buf.size());
+static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
+    dst.nx = src.nx;
+    dst.ny = src.ny;
+    dst.buf.resize(src.buf.size());
-    for (size_t i = 0; i < src->buf.size(); ++i) {
+    // TODO @ngxson : seems like this could be done more efficiently on cgraph
+    for (size_t i = 0; i < src.buf.size(); ++i) {
         int c = i % 3; // rgb
-        dst->buf[i] = (static_cast<float>(src->buf[i]) / 255.0f - mean[c]) / std[c];
+        dst.buf[i] = (static_cast<float>(src.buf[i]) / 255.0f - mean[c])
/ std[c]; } } @@ -1874,7 +1775,7 @@ inline int clip(int x, int lower, int upper) { return std::max(lower, std::min(x, upper)); } -static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) { +static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { const int nx = img.nx; const int ny = img.ny; @@ -2012,13 +1913,13 @@ static std::pair select_best_resolution(const std::pair & or return best_fit; } -static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { - std::vector patches; +static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { + std::vector patches; int width = image.nx; int height = image.ny; for (int i = 0; i < height; i += patch_size) { for (int j = 0; j < width; j += patch_size) { - clip_image_u8 *patch = clip_image_u8_init(); + clip_image_u8_ptr patch(clip_image_u8_init()); patch->nx = std::min(patch_size, width - j); patch->ny = std::min(patch_size, height - i); patch->buf.resize(3 * patch->nx * patch->ny); @@ -2029,7 +1930,7 @@ static std::vector divide_to_patches_u8(const clip_image_u8 & im } } } - patches.push_back(patch); + patches.push_back(std::move(patch)); } } return patches; @@ -2110,7 +2011,7 @@ static std::pair uhd_best_grid(const int max_slice_nums, const int mul // -> https://arxiv.org/pdf/2403.11703 // -> https://github.com/thunlp/LLaVA-UHD // -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118 -static std::vector> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) { +static std::vector> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) { const std::pair original_size={img->nx,img->ny}; const int original_width = img->nx; const int original_height = img->ny; @@ -2118,33 +2019,33 @@ static std::vector> uhd_slice_image(const clip_imag const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); const int multiple = fmin(ceil(ratio), max_slice_nums); - std::vector> images; - LOG_INF("%s: multiple %d\n", __func__, multiple); - images.push_back(std::vector()); + std::vector> images; + LOG_DBG("%s: multiple %d\n", __func__, multiple); + images.push_back(std::vector()); if (multiple <= 1) { auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true); - clip_image_u8 * source_image = clip_image_u8_init(); + clip_image_u8_ptr source_image(clip_image_u8_init()); bicubic_resize(*img, *source_image, best_size.first, best_size.second); // source_image = image.resize(best_size, Image.Resampling.BICUBIC) - images[images.size()-1].push_back(source_image); + images.back().push_back(std::move(source_image)); } else if (multiple > 1) { auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size); - clip_image_u8 * source_image = clip_image_u8_init(); + clip_image_u8_ptr source_image(clip_image_u8_init()); bicubic_resize(*img, *source_image, best_size.first, best_size.second); // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) - LOG_INF("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second); - images[images.size()-1].push_back(source_image); + LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, 
img->nx, img->ny, best_size.first, best_size.second); + images.back().push_back(std::move(source_image)); std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); - LOG_INF("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); + LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); - clip_image_u8 * refine_image = clip_image_u8_init(); + clip_image_u8_ptr refine_image(clip_image_u8_init()); bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second); - LOG_INF("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); + LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); // split_to_patches int width = refine_image->nx; @@ -2152,9 +2053,9 @@ static std::vector> uhd_slice_image(const clip_imag int grid_x = int(width / best_grid.first); int grid_y = int(height / best_grid.second); for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){ - images.push_back(std::vector()); + images.push_back(std::vector()); for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){ - clip_image_u8 * patch = clip_image_u8_init(); + clip_image_u8_ptr patch(clip_image_u8_init()); patch->nx = grid_x; patch->ny = grid_y; patch->buf.resize(3 * patch->nx * patch->ny); @@ -2167,10 +2068,9 @@ static std::vector> uhd_slice_image(const clip_imag patch->buf[j+2] = refine_image->buf[i+2]; } } - images[images.size()-1].push_back(patch); + images.back().push_back(std::move(patch)); } } - clip_image_u8_free(refine_image); } return images; } @@ -2178,8 +2078,8 @@ static std::vector> uhd_slice_image(const clip_imag int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { const int max_slice_nums=9; const int scale_resolution=448; - const int original_width = ctx_clip->load_image_size->width; - const int original_height = ctx_clip->load_image_size->height; + const int original_width = ctx_clip->load_image_size.width; + const int original_height = ctx_clip->load_image_size.height; const float log_ratio = log(1.0*original_width/original_height); const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); const int multiple = fmin(ceil(ratio), max_slice_nums); @@ -2189,88 +2089,64 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found -bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) { +bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { - if(clip_is_minicpmv(ctx)){ + if (clip_is_minicpmv(ctx)) { int max_slice_nums = 9; - std::vector> imgs = uhd_slice_image(img, max_slice_nums); - res_imgs->size = 0; - for (size_t i = 0; i < imgs.size(); ++i){ - res_imgs->size += imgs[i].size(); - } - res_imgs->data = new clip_image_f32[res_imgs->size]; - int idx = 0; + std::vector> imgs = 
uhd_slice_image(img, max_slice_nums); for (size_t i = 0; i < imgs.size(); ++i) { for (size_t j = 0; j < imgs[i].size(); ++j) { LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny); - clip_image_f32 * res = clip_image_f32_init(); - normalize_image_u8_to_f32(imgs[i][j], res, ctx->image_mean, ctx->image_std); - res_imgs->data[idx++] = *res; - clip_image_f32_free(res); - } - } - for (size_t i = 0; i < imgs.size(); ++i) { - for (size_t j = 0; j < imgs[i].size(); ++j) { - if (imgs[i][j] != nullptr) { - clip_image_u8_free(imgs[i][j]); - } + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i][j], *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); } } return true; } else if (ctx->has_qwen2vl_merger) { - clip_image_u8 * resized = clip_image_u8_init(); - auto patch_size = clip_patch_size(ctx) * 2; + clip_image_u8 resized; + auto patch_size = clip_get_patch_size(ctx) * 2; int nx = ceil((float)img->nx / patch_size) * patch_size; int ny = ceil((float)img->ny / patch_size) * patch_size; - bicubic_resize(*img, *resized, nx, ny); + bicubic_resize(*img, resized, nx, ny); - res_imgs->data = new clip_image_f32[1]; - // clip_image_f32 * res = clip_image_f32_init(); - normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std); + clip_image_f32_ptr img_f32(clip_image_f32_init()); + // clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std); // res_imgs->data[0] = *res; - res_imgs->size = 1; - - // clip_image_f32_free(res); - clip_image_u8_free(resized); + res_imgs->entries.push_back(std::move(img_f32)); return true; } - if (ctx->has_glm_projector) { - res_imgs->size = 1; - res_imgs->data = new clip_image_f32[res_imgs->size]; + if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { clip_image_u8 resized_image; int32_t sz=ctx->vision_model.hparams.image_size; bicubic_resize(*img, resized_image,sz,sz); - clip_image_f32 * res = clip_image_f32_init(); + clip_image_f32_ptr img_f32(clip_image_f32_init()); //clip_image_save_to_bmp(resized_image, "resized.bmp"); - normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std); - res_imgs->data[0] = *res; - clip_image_f32_free(res); + normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(img_f32)); return true; } bool pad_to_square = true; if (!ctx->has_vision_encoder) { - LOG_ERR("This gguf file seems to have no vision encoder\n"); + LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); return false; } auto & params = ctx->vision_model.hparams; // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing - if (strcmp(params.mm_patch_merge_type, "spatial_unpad") == 0) { + if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) { pad_to_square = false; } // free the previous res_imgs if any set - if (res_imgs->size > 0) { - clip_image_f32_batch_free(res_imgs); - } - res_imgs->data = nullptr; - res_imgs->size = 0; + res_imgs->entries.clear(); // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 - clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily + clip_image_u8_ptr 
temp(clip_image_u8_init()); // we will keep the input image data here temporarily if (pad_to_square && img->nx != img->ny) { int longer_side = std::max(img->nx, img->ny); temp->nx = longer_side; @@ -2313,28 +2189,18 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli // clip_image_u8_free(temp2); // } - std::vector patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) + std::vector patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) - clip_image_u8 *image_original_resize = clip_image_u8_init(); + clip_image_u8_ptr image_original_resize(clip_image_u8_init()); // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - patches.insert(patches.begin(), image_original_resize); - // clip_image_f32_batch_init(patches.size()); - res_imgs->size = patches.size(); - res_imgs->data = new clip_image_f32[res_imgs->size]; - int num=0; - for (auto& patch : patches) { - normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std); - num++; + patches.insert(patches.begin(), std::move(image_original_resize)); + for (auto & patch : patches) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*patch, *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); } - for (size_t i = 0; i < patches.size(); i++) { - // LOG_DBG("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny); - clip_image_u8_free(patches[i]); - } - - clip_image_u8_free(temp); - return true; } else { temp->nx = img->nx; @@ -2350,7 +2216,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli const int nx2 = ctx->vision_model.hparams.image_size; const int ny2 = ctx->vision_model.hparams.image_size; - clip_image_f32 * res = clip_image_f32_init(); + clip_image_f32_ptr res(clip_image_f32_init()); res->nx = nx2; res->ny = ny2; res->buf.resize(3 * nx2 * ny2); @@ -2402,7 +2268,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli } } } - clip_image_u8_free(temp); // { // clip_image_u8 * temp2 = clip_image_u8_init(); @@ -2412,10 +2277,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli // } // res_imgs.push_back(res); - res_imgs->size = 1; - res_imgs->data = new clip_image_f32[res_imgs->size]; - res_imgs->data[0] = *res; - clip_image_f32_free(res); + res_imgs->entries.push_back(std::move(res)); return true; } @@ -2425,12 +2287,9 @@ ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) { } void clip_free(clip_ctx * ctx) { - ggml_free(ctx->ctx_data); - gguf_free(ctx->ctx_gguf); - - ggml_backend_buffer_free(ctx->params_buffer); - ggml_backend_free(ctx->backend); - ggml_gallocr_free(ctx->compute_alloc); + if (ctx == nullptr) { + return; + } delete ctx; } @@ -2446,20 +2305,20 @@ size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float); } -int32_t clip_image_size(const struct clip_ctx * ctx) { +int32_t clip_get_image_size(const struct clip_ctx * ctx) { return ctx->vision_model.hparams.image_size; } -int32_t clip_patch_size(const struct 
clip_ctx * ctx) { +int32_t clip_get_patch_size(const struct clip_ctx * ctx) { return ctx->vision_model.hparams.patch_size; } -int32_t clip_hidden_size(const struct clip_ctx * ctx) { +int32_t clip_get_hidden_size(const struct clip_ctx * ctx) { return ctx->vision_model.hparams.hidden_size; } const char * clip_patch_merge_type(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.mm_patch_merge_type; + return ctx->vision_model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat"; } const int32_t * clip_image_grid(const struct clip_ctx * ctx) { @@ -2502,6 +2361,8 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); n_patches = x_patch * y_patch; + } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { + n_patches = 256; } return n_patches; @@ -2595,23 +2456,27 @@ static std::vector> get_2d_sincos_pos_embed(int embed_dim, co bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { if (!ctx->has_vision_encoder) { - LOG_ERR("This gguf file seems to have no vision encoder\n"); + LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); return false; } - clip_image_f32_batch imgs{}; - imgs.size = 1; - imgs.data = img; + clip_image_f32_batch imgs; + clip_image_f32_ptr img_copy(clip_image_f32_init()); + *img_copy = *img; + imgs.entries.push_back(std::move(img_copy)); + return clip_image_batch_encode(ctx, n_threads, &imgs, vec); } -bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) { +bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) { + const clip_image_f32_batch & imgs = *imgs_c_ptr; + if (!ctx->has_vision_encoder) { - LOG_ERR("This gguf file seems to have no vision encoder\n"); + LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); return false; } - int batch_size = imgs->size; + int batch_size = imgs.entries.size(); if (ctx->has_llava_projector) { GGML_ASSERT(batch_size == 1); // TODO: support multiple images } @@ -2626,8 +2491,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } // build the inference graph + ggml_backend_sched_reset(ctx->sched.get()); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); - ggml_gallocr_alloc_graph(ctx->compute_alloc, gf); + ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); // set inputs const auto & model = ctx->vision_model; @@ -2637,25 +2503,22 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima int image_size_width = image_size; int image_size_height = image_size; if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) { - image_size_width = imgs->data[0].nx; - image_size_height = imgs->data[0].ny; + image_size_width = imgs.entries[0]->nx; + image_size_height = imgs.entries[0]->ny; } const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); - const int num_positions = num_patches + (ctx->has_class_embedding ? 
1 : 0); - if(ctx->load_image_size==nullptr){ - ctx->load_image_size= clip_image_size_init(); - } - const int pos_w = ctx->load_image_size->width/patch_size; - const int pos_h = ctx->load_image_size->height/patch_size; + const int num_positions = num_patches + (model.class_embedding ? 1 : 0); + const int pos_w = ctx->load_image_size.width / patch_size; + const int pos_h = ctx->load_image_size.height / patch_size; { struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); float * data = (float *)malloc(ggml_nbytes(inp_raw)); - for (size_t i = 0; i < imgs->size; i++) { - const int nx = imgs->data[i].nx; - const int ny = imgs->data[i].ny; + for (size_t i = 0; i < imgs.entries.size(); i++) { + const int nx = imgs.entries[i]->nx; + const int ny = imgs.entries[i]->ny; if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) { GGML_ASSERT(nx == image_size && ny == image_size); } @@ -2666,7 +2529,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (int k = 0; k < 3; k++) { for (int y = 0; y < ny; y++) { for (int x = 0; x < nx; x++) { - data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k]; + data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k]; } } } @@ -2727,16 +2590,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima free(pos_embed_data); } } - else{ - { - if (ctx->has_class_embedding) { - struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); + else { + if (model.class_embedding) { + struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); - void* zero_mem = malloc(ggml_nbytes(embeddings)); - memset(zero_mem, 0, ggml_nbytes(embeddings)); - ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); - free(zero_mem); - } + void* zero_mem = malloc(ggml_nbytes(embeddings)); + memset(zero_mem, 0, ggml_nbytes(embeddings)); + ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); + free(zero_mem); } if (ctx->has_qwen2vl_merger) { @@ -2766,6 +2627,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); } + else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { + // do nothing + } else { struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); @@ -2781,7 +2645,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // The patches vector is used to get rows to index into the embeds with; // we should skip dim 0 only if we have CLS to avoid going out of bounds // when retrieving the rows. - int patch_offset = ctx->has_class_embedding ? 1 : 0; + int patch_offset = model.class_embedding ? 
1 : 0; int* patches_data = (int*)malloc(ggml_nbytes(patches)); for (int i = 0; i < num_patches; i++) { patches_data[i] = i + patch_offset; @@ -2792,11 +2656,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } - if (ggml_backend_is_cpu(ctx->backend)) { - ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); - } + ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads); - ggml_backend_graph_compute(ctx->backend, gf); + auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf); + if (status != GGML_STATUS_SUCCESS) { + LOG_ERR("%s: ggml_backend_sched_graph_compute failed with error %d\n", __func__, status); + return false; + } // the last node is the embedding tensor struct ggml_tensor * embeddings = ggml_graph_node(gf, -1); @@ -2818,10 +2684,13 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i assert(itype < GGML_TYPE_COUNT); ggml_type type = static_cast(itype); - auto * ctx_clip = clip_model_load(fname_inp, 2); + auto * ctx_clip = clip_init(fname_inp, clip_context_params{ + /* use_gpu */ false, + /* verbosity */ GGML_LOG_LEVEL_ERROR, + }); - const auto & ctx_src = ctx_clip->ctx_gguf; - const auto & ctx_data = ctx_clip->ctx_data; + const auto & ctx_src = ctx_clip->ctx_gguf.get(); + const auto & ctx_data = ctx_clip->ctx_data.get(); auto * ctx_out = gguf_init_empty(); gguf_set_kv(ctx_out, ctx_src); @@ -2895,7 +2764,7 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i f32_data = (float *)conv_buf.data(); break; default: - LOG_ERR("Please use an input file in f32 or f16\n"); + LOG_ERR("%s: Please use an input file in f32 or f16\n", __func__); gguf_free(ctx_out); return false; } @@ -2976,9 +2845,12 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { return ctx->vision_model.mm_1_b->ne[0]; } + if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { + return ctx->vision_model.mm_input_proj_w->ne[0]; + } std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; - throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); + throw std::runtime_error(string_format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); } int clip_is_minicpmv(const struct clip_ctx * ctx) { @@ -2991,10 +2863,19 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) { bool clip_is_glm(const struct clip_ctx * ctx) { return ctx->has_glm_projector; } + bool clip_is_qwen2vl(const struct clip_ctx * ctx) { return ctx->has_qwen2vl_merger; } +bool clip_is_llava(const struct clip_ctx * ctx) { + return ctx->has_llava_projector; +} + +bool clip_is_gemma3(const struct clip_ctx * ctx) { + return ctx->proj_type == PROJECTOR_TYPE_GEMMA3; +} + // Determine the number of encoder layers to iterate over int get_deepest_feature_layer(const struct clip_ctx * ctx) { // Get the index of the second to last layer; this is the @@ -3030,3 +2911,11 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, clip_image_encode(ctx, n_threads, &clip_img, vec); return true; } + +// +// API used internally with mtmd +// + +projector_type clip_get_projector_type(const struct clip_ctx * ctx) { + return ctx->proj_type; +} diff --git a/llama/llama.cpp/examples/llava/clip.h b/llama/llama.cpp/examples/llava/clip.h index f9f80d7d9..5fc45d3e2 100644 --- a/llama/llama.cpp/examples/llava/clip.h +++ b/llama/llama.cpp/examples/llava/clip.h @@ -1,6 +1,7 @@ #ifndef CLIP_H #define CLIP_H +#include 
"ggml.h" #include #include @@ -29,27 +30,28 @@ struct clip_image_size { int height; }; -struct clip_image_u8_batch { - struct clip_image_u8 * data; - size_t size; +struct clip_image_f32; +struct clip_image_u8_batch; +struct clip_image_f32_batch; + +struct clip_context_params { + bool use_gpu; + enum ggml_log_level verbosity; }; -struct clip_image_f32_batch { - struct clip_image_f32 * data; - size_t size; -}; +// deprecated, use clip_init +CLIP_API struct clip_ctx * clip_model_load(const char * fname, int verbosity); -CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity); -CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity); +CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params); CLIP_API void clip_free(struct clip_ctx * ctx); CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx); CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w); -CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx); -CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx); -CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx); +CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx); +CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx); +CLIP_API int32_t clip_get_hidden_size(const struct clip_ctx * ctx); // TODO: should be enum, not string CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); @@ -65,16 +67,30 @@ CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); -CLIP_API struct clip_image_size * clip_image_size_init(); -CLIP_API struct clip_image_u8 * clip_image_u8_init (); -CLIP_API struct clip_image_f32 * clip_image_f32_init(); +CLIP_API struct clip_image_size * clip_image_size_init(); +CLIP_API struct clip_image_u8 * clip_image_u8_init (); +CLIP_API struct clip_image_f32 * clip_image_f32_init(); +CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava +// nx, ny are the output image dimensions +CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny); + +CLIP_API void clip_image_size_free (struct clip_image_size * img_size); CLIP_API void clip_image_u8_free (struct clip_image_u8 * img); CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); -/** build image from pixels decoded by other libraries instead of stb_image.h for better performance. 
The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */ +// use for accessing underlay data of clip_image_f32_batch +CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size() +CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx +CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny +CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data + +/** + * Build image from pixels decoded by other libraries instead of stb_image.h for better performance. + * The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes + */ CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img); CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); @@ -95,6 +111,8 @@ CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); CLIP_API bool clip_is_glm(const struct clip_ctx * ctx); CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); +CLIP_API bool clip_is_llava(const struct clip_ctx * ctx); +CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx); CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx); diff --git a/llama/llama.cpp/examples/llava/llava.cpp b/llama/llama.cpp/examples/llava/llava.cpp index f0e484a1c..5eb40bcd1 100644 --- a/llama/llama.cpp/examples/llava/llava.cpp +++ b/llama/llama.cpp/examples/llava/llava.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #if defined(LLAVA_LOG_OFF) # define LOG_INF(...) @@ -45,6 +46,17 @@ struct clip_image_grid_shape { int second; }; +// convenience cpp wrapper +struct clip_image_f32_batch_deleter { + void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } +}; +typedef std::unique_ptr clip_image_f32_batch_ptr; + +struct clip_image_size_deleter { + void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } +}; +typedef std::unique_ptr clip_image_size_ptr; + /** * Selects the best resolution from a list of possible resolutions based on the original size. 
* @@ -105,8 +117,8 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector struct ggml_context * ctx; } model; - const int32_t image_size = clip_image_size(ctx_clip); - const int32_t patch_size = clip_patch_size(ctx_clip); + const int32_t image_size = clip_get_image_size(ctx_clip); + const int32_t patch_size = clip_get_patch_size(ctx_clip); int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches) @@ -246,12 +258,9 @@ static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { // std::vector img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 - clip_image_f32_batch img_res_v; - img_res_v.size = 0; - img_res_v.data = nullptr; - if (!clip_image_preprocess(ctx_clip, img, &img_res_v)) { + clip_image_f32_batch_ptr img_res_v(clip_image_f32_batch_init()); + if (!clip_image_preprocess(ctx_clip, img, img_res_v.get())) { LOG_ERR("%s: unable to preprocess image\n", __func__); - delete[] img_res_v.data; return false; } @@ -259,66 +268,72 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); + const size_t n_imgs = clip_image_f32_batch_n_images(img_res_v.get()); + if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) { std::vector image_embd_v; - image_embd_v.resize(img_res_v.size); - struct clip_image_size * load_image_size = clip_image_size_init(); + image_embd_v.resize(n_imgs); + clip_image_size load_image_size; - for (size_t i = 0; i < img_res_v.size; i++) { + for (size_t i = 0; i < n_imgs; i++) { const int64_t t_img_enc_step_start_us = ggml_time_us(); - image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny)); - int patch_size=14; - load_image_size->width = img_res_v.data[i].nx; - load_image_size->height = img_res_v.data[i].ny; - clip_add_load_image_size(ctx_clip, load_image_size); + int nx = clip_image_f32_batch_nx(img_res_v.get(), i); + int ny = clip_image_f32_batch_ny(img_res_v.get(), i); + image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, nx, ny)); + int patch_size = 14; + load_image_size.width = nx; + load_image_size.height = ny; + clip_add_load_image_size(ctx_clip, &load_image_size); bool encoded = false; + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); if (clip_is_qwen2vl(ctx_clip)) { - encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); + encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); } else { - encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); + encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(img_res, patch_size), image_embd_v[i]); } if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); + LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); return false; } const int64_t t_img_enc_steop_batch_us = ggml_time_us(); - LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)img_res_v.size, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); + LOG_INF("%s: step %d 
of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)n_imgs, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); } const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); int n_img_pos_out = 0; for (size_t i = 0; i < image_embd_v.size(); i++) { + int nx = clip_image_f32_batch_nx(img_res_v.get(), i); + int ny = clip_image_f32_batch_ny(img_res_v.get(), i); + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); std::memcpy( image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), image_embd_v[i], - clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny)); - n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]); + clip_embd_nbytes_by_img(ctx_clip, nx, ny)); + n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res); } *n_img_pos = n_img_pos_out; for (size_t i = 0; i < image_embd_v.size(); i++) { free(image_embd_v[i]); } image_embd_v.clear(); - load_image_size->width = img->nx; - load_image_size->height = img->ny; - clip_add_load_image_size(ctx_clip, load_image_size); - LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); - delete[] img_res_v.data; - img_res_v.size = 0; - img_res_v.data = nullptr; + load_image_size.width = img->nx; + load_image_size.height = img->ny; + clip_add_load_image_size(ctx_clip, &load_image_size); + LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size.width, load_image_size.height); } else if (clip_is_glm(ctx_clip)){ struct clip_image_size * load_image_size = clip_image_size_init(); - load_image_size->width = img_res_v.data[0].nx; - load_image_size->height = img_res_v.data[0].ny; + load_image_size->width = clip_image_f32_batch_nx(img_res_v.get(), 0); + load_image_size->height = clip_image_f32_batch_ny(img_res_v.get(), 0); clip_add_load_image_size(ctx_clip, load_image_size); - bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); - int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2); + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); + bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); + int pos = int(load_image_size->width/clip_get_patch_size(ctx_clip)/2); *n_img_pos = (pos * pos + 2); if (!encoded){ LOG_ERR("Unable to encode image \n"); @@ -328,8 +343,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding *n_img_pos = clip_n_patches(ctx_clip); - bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096 - delete[] img_res_v.data; + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); + bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096 if (!encoded) { LOG_ERR("Unable to encode image\n"); @@ -340,17 +355,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli // spatial_unpad llava-1.6 type embedding // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working std::vector image_embd_v; - 
image_embd_v.resize(img_res_v.size); - for (size_t i = 0; i < img_res_v.size; i++) { + image_embd_v.resize(n_imgs); + for (size_t i = 0; i < n_imgs; i++) { + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 - const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside + const bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); + LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); return false; } } const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); const int32_t * image_grid = clip_image_grid(ctx_clip); const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip); @@ -360,12 +376,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); } - // free all img_res_v - not needed anymore - delete[] img_res_v.data; - img_res_v.size = 0; - img_res_v.data = nullptr; - - const int32_t image_size = clip_image_size(ctx_clip); + const int32_t image_size = clip_get_image_size(ctx_clip); struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size); diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h index 167747111..f91896e48 100644 --- a/llama/llama.cpp/include/llama.h +++ b/llama/llama.cpp/include/llama.h @@ -60,6 +60,7 @@ extern "C" { struct llama_model; struct llama_context; struct llama_sampler; + struct llama_kv_cache; typedef int32_t llama_pos; typedef int32_t llama_token; @@ -106,6 +107,10 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, }; enum llama_rope_type { @@ -277,10 +282,18 @@ extern "C" { }; }; + struct llama_model_tensor_buft_override { + const char * pattern; + ggml_backend_buffer_type_t buft; + }; + struct llama_model_params { // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) ggml_backend_dev_t * devices; + // NULL-terminated list of buffer types to use for tensors that match a pattern + const struct llama_model_tensor_buft_override * tensor_buft_overrides; + int32_t n_gpu_layers; // number of layers to store in VRAM enum llama_split_mode split_mode; // how to split the model across multiple GPUs @@ -356,17 +369,18 @@ extern "C" { // model quantization parameters typedef struct llama_model_quantize_params { - int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() - enum llama_ftype ftype; // quantize to this llama_ftype - enum ggml_type output_tensor_type; // output tensor type - enum ggml_type 
token_embedding_type; // token embeddings tensor type - bool allow_requantize; // allow quantizing non-f32/f16 tensors - bool quantize_output_tensor; // quantize output.weight - bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored - bool pure; // quantize all tensors to the default type - bool keep_split; // quantize to the same number of shards - void * imatrix; // pointer to importance matrix data - void * kv_overrides; // pointer to vector containing overrides + int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + enum ggml_type output_tensor_type; // output tensor type + enum ggml_type token_embedding_type; // token embeddings tensor type + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + bool pure; // quantize all tensors to the default type + bool keep_split; // quantize to the same number of shards + void * imatrix; // pointer to importance matrix data + void * kv_overrides; // pointer to vector containing overrides + void * tensor_types; // pointer to vector containing tensor types } llama_model_quantize_params; typedef struct llama_logit_bias { @@ -475,7 +489,8 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); - LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); + LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx); + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); @@ -592,7 +607,7 @@ extern "C" { // KV cache // - // TODO: remove llama_kv_cache_view_* API + // TODO: start using struct llama_kv_cache // Information associated with an individual cell in the KV cache view. struct llama_kv_cache_view_cell { @@ -647,13 +662,19 @@ extern "C" { // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times - LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); + LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx); + + DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx), + "use llama_kv_self_n_tokens instead"); // Returns the number of used KV cells (i.e. 
have at least one sequence assigned to them) - LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); + LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx); + + DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx), + "use llama_kv_self_used_cells instead"); // Clear the KV cache - both cell info is erased and KV data is zeroed - LLAMA_API void llama_kv_cache_clear( + LLAMA_API void llama_kv_self_clear( struct llama_context * ctx); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) @@ -661,7 +682,7 @@ extern "C" { // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API bool llama_kv_cache_seq_rm( + LLAMA_API bool llama_kv_self_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -671,7 +692,7 @@ extern "C" { // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_cp( + LLAMA_API void llama_kv_self_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, @@ -679,17 +700,17 @@ extern "C" { llama_pos p1); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_cache_seq_keep( + LLAMA_API void llama_kv_self_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() - // - explicitly with llama_kv_cache_update() + // - explicitly with llama_kv_self_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_add( + LLAMA_API void llama_kv_self_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -699,10 +720,10 @@ extern "C" { // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() - // - explicitly with llama_kv_cache_update() + // - explicitly with llama_kv_self_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_div( + LLAMA_API void llama_kv_self_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -710,24 +731,76 @@ extern "C" { int d); // Returns the largest position present in the KV cache for the specified sequence - LLAMA_API llama_pos llama_kv_cache_seq_pos_max( + LLAMA_API llama_pos llama_kv_self_seq_pos_max( struct llama_context * ctx, - llama_seq_id seq_id); - - // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache - // how to avoid this? + llama_seq_id seq_id); // Defragment the KV cache // This will be applied: // - lazily on next llama_decode() - // - explicitly with llama_kv_cache_update() - LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx); - - // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) 
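
// Illustrative migration sketch (assumed caller code, not part of the patch): the renamed
// llama_kv_self_* entry points above are drop-in replacements for the old llama_kv_cache_*
// helpers, which now only remain as DEPRECATED shims.
#include "llama.h"

// Drop everything a sequence holds beyond the first n_keep positions, then query usage.
static void truncate_sequence(struct llama_context * ctx, llama_seq_id seq_id, llama_pos n_keep) {
    // was: llama_kv_cache_seq_rm(ctx, seq_id, n_keep, -1);  (p1 < 0 means "up to infinity")
    llama_kv_self_seq_rm(ctx, seq_id, n_keep, -1);

    // was: llama_get_kv_cache_used_cells(ctx);
    const int32_t used_cells = llama_kv_self_used_cells(ctx);
    (void) used_cells;
}
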
- LLAMA_API void llama_kv_cache_update(struct llama_context * ctx); + // - explicitly with llama_kv_self_update() + LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx); // Check if the context supports KV cache shifting - LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx); + LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); + + // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) + LLAMA_API void llama_kv_self_update(struct llama_context * ctx); + + DEPRECATED(LLAMA_API void llama_kv_cache_clear( + struct llama_context * ctx), + "use llama_kv_self_clear instead"); + + DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1), + "use llama_kv_self_seq_rm instead"); + + DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1), + "use llama_kv_self_seq_cp instead"); + + DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_kv_self_seq_keep instead"); + + DEPRECATED(LLAMA_API void llama_kv_cache_seq_add( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta), + "use llama_kv_self_seq_add instead"); + + DEPRECATED(LLAMA_API void llama_kv_cache_seq_div( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d), + "use llama_kv_self_seq_div instead"); + + DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_kv_self_seq_pos_max instead"); + + DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx), + "use llama_kv_self_defrag instead"); + + DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx), + "use llama_kv_self_can_shift instead"); + + DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx), + "use llama_kv_self_update instead"); + // // State / sessions @@ -891,6 +964,10 @@ extern "C" { // If set to true, the model will only attend to the past tokens LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn); + // Set whether the model is in warmup mode or not + // If true, all model tensors are activated during llama_decode() to load and cache their weights. + LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); + // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); @@ -1206,22 +1283,38 @@ extern "C" { float tau, float eta); + /// @details Intializes a GBNF grammar, see grammars/README.md for details. + /// @param vocab The vocabulary that this grammar will be used with. + /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. + /// @param grammar_root The name of the start symbol for the grammar. LLAMA_API struct llama_sampler * llama_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); - /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 - /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. 
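
// Illustrative sketch (assumed caller code, not part of the patch): constructing a sampler
// with the GBNF entry point documented above. The grammar string is a stand-in example;
// per the new doc comment, NULL is returned when parsing the grammar fails.
#include "llama.h"

static struct llama_sampler * make_yes_no_sampler(const struct llama_vocab * vocab) {
    const char * grammar_str = "root ::= \"yes\" | \"no\"";

    struct llama_sampler * smpl = llama_sampler_init_grammar(vocab, grammar_str, "root");
    if (smpl == NULL) {
        // grammar_str did not parse as GBNF
    }
    return smpl;
}
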
- /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. - LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( + DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, - size_t num_trigger_tokens); + size_t num_trigger_tokens), + "use llama_sampler_init_grammar_lazy_patterns instead"); + + + /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 + /// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. LLAMA_API struct llama_sampler * llama_sampler_init_penalties( diff --git a/llama/llama.cpp/src/llama-adapter.cpp b/llama/llama.cpp/src/llama-adapter.cpp index 8a0800463..7ac54d239 100644 --- a/llama/llama.cpp/src/llama-adapter.cpp +++ b/llama/llama.cpp/src/llama-adapter.cpp @@ -4,14 +4,13 @@ #include "llama-mmap.h" #include "llama-model.h" -#include #include #include #include // vec -struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const { +ggml_tensor * llama_adapter_cvec::tensor_for(int il) const { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { return nullptr; } @@ -19,7 +18,7 @@ struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const { return tensors[il]; } -struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { +ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const { ggml_tensor * layer_dir = tensor_for(il); if (layer_dir != nullptr) { cur = ggml_add(ctx, cur, layer_dir); @@ -40,7 +39,7 @@ bool llama_adapter_cvec::init(const llama_model & model) { auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { - struct ggml_init_params params = { + ggml_init_params params = { /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, @@ -91,7 +90,7 @@ bool llama_adapter_cvec::init(const llama_model & model) { return true; } -int32_t llama_adapter_cvec::apply( +bool llama_adapter_cvec::apply( const llama_model & model, const float * data, size_t len, @@ -104,17 +103,17 @@ int32_t llama_adapter_cvec::apply( // disable the current control vector (but leave allocated for later) layer_start = -1; layer_end = -1; - return 0; + return true; } if (n_embd != (int) hparams.n_embd) { LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__); - return 1; + return false; } if (tensors.empty()) { if (!init(model)) { - return 1; + return false; } } @@ -130,12 
+129,12 @@ int32_t llama_adapter_cvec::apply( } } - return 0; + return true; } // lora -llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) { +llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) { const std::string name(w->name); const auto pos = ab_map.find(name); @@ -146,11 +145,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * return nullptr; } -static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) { +static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) { LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); ggml_context * ctx_init; - struct gguf_init_params meta_gguf_params = { + gguf_init_params meta_gguf_params = { /* .no_alloc = */ true, /* .ctx = */ &ctx_init, }; @@ -201,7 +200,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char auto it = ctx_map.find(buft); if (it == ctx_map.end()) { // add a new context - struct ggml_init_params params = { + ggml_init_params params = { /*.mem_size =*/ n_tensors*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, @@ -248,6 +247,26 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char } } + // get extra buffer types of the CPU + // TODO: a more general solution for non-CPU extra buft should be imlpemented in the future + // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948 + std::vector buft_extra; + { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_extra.emplace_back(*extra_bufts); + ++extra_bufts; + } + } + } + // add tensors for (auto & it : ab_map) { const std::string & name = it.first; @@ -264,7 +283,23 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); } - struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); + auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer); + + // do not load loras to extra buffer types (i.e. 
bufts for repacking) -> use the CPU in that case + for (auto & ex : buft_extra) { + if (ex == buft) { + LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); + + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + buft = ggml_backend_dev_buffer_type(cpu_dev); + + break; + } + } + + LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); + + ggml_context * dev_ctx = ctx_for_buft(buft); // validate tensor shape if (is_token_embd) { // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() @@ -281,8 +316,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char } // save tensor to adapter - struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a); - struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); + ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a); + ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); ggml_set_name(tensor_a, w.a->name); ggml_set_name(tensor_b, w.b->name); adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b); @@ -308,7 +343,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char { llama_file gguf_file(path_lora, "rb"); std::vector read_buf; - auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) { + auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) { size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name)); size_t size = ggml_nbytes(orig); read_buf.resize(size); @@ -327,8 +362,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); } -struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) { - struct llama_adapter_lora * adapter = new llama_adapter_lora(); +llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) { + llama_adapter_lora * adapter = new llama_adapter_lora(); try { llama_adapter_lora_init_impl(*model, path_lora, *adapter); @@ -342,6 +377,6 @@ struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, return nullptr; } -void llama_adapter_lora_free(struct llama_adapter_lora * adapter) { +void llama_adapter_lora_free(llama_adapter_lora * adapter) { delete adapter; } diff --git a/llama/llama.cpp/src/llama-adapter.h b/llama/llama.cpp/src/llama-adapter.h index 603fa08f6..65824e972 100644 --- a/llama/llama.cpp/src/llama-adapter.h +++ b/llama/llama.cpp/src/llama-adapter.h @@ -15,11 +15,11 @@ // struct llama_adapter_cvec { - struct ggml_tensor * tensor_for(int il) const; + ggml_tensor * tensor_for(int il) const; - struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const; + ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const; - int32_t apply( + bool apply( const llama_model & model, const float * data, size_t len, @@ -36,7 +36,7 @@ private: std::vector ctxs; std::vector bufs; - std::vector tensors; // per layer + std::vector tensors; // per layer }; // @@ -44,8 +44,8 @@ private: // struct llama_adapter_lora_weight { - struct ggml_tensor * a = nullptr; - struct ggml_tensor * b = nullptr; + ggml_tensor * a = nullptr; + ggml_tensor * b = nullptr; // get actual scale based on rank and alpha float 
get_scale(float alpha, float adapter_scale) const { @@ -55,12 +55,12 @@ struct llama_adapter_lora_weight { } llama_adapter_lora_weight() = default; - llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} + llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {} }; struct llama_adapter_lora { // map tensor name to lora_a_b - std::unordered_map ab_map; + std::unordered_map ab_map; std::vector ctxs; std::vector bufs; @@ -70,5 +70,7 @@ struct llama_adapter_lora { llama_adapter_lora() = default; ~llama_adapter_lora() = default; - llama_adapter_lora_weight * get_weight(struct ggml_tensor * w); + llama_adapter_lora_weight * get_weight(ggml_tensor * w); }; + +using llama_adapter_loras = std::unordered_map; diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index 13a0a9888..bdf3d898b 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -7,6 +7,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, { LLM_ARCH_MLLAMA, "mllama" }, + { LLM_ARCH_LLAMA4, "llama4" }, { LLM_ARCH_DECI, "deci" }, { LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_GROK, "grok" }, @@ -26,6 +27,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_QWEN2MOE, "qwen2moe" }, { LLM_ARCH_QWEN2VL, "qwen2vl" }, + { LLM_ARCH_QWEN3, "qwen3" }, + { LLM_ARCH_QWEN3MOE, "qwen3moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -52,6 +55,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -60,11 +64,15 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, + { LLM_ARCH_RWKV7, "rwkv7" }, + { LLM_ARCH_ARWKV7, "arwkv7" }, { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_SOLAR, "solar" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_PLM, "plm" }, + { LLM_ARCH_BAILINGMOE, "bailingmoe" }, { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -74,6 +82,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, { LLM_KV_GENERAL_NAME, "general.name" }, { LLM_KV_GENERAL_AUTHOR, "general.author" }, { LLM_KV_GENERAL_VERSION, "general.version" }, @@ -112,25 +121,30 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, + { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, - { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, - { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, - { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, - { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, - { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, - { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, - { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, - { 
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, - { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, - { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, - { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, - { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, - { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, - { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, - { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, - { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, - { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, - { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" }, + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, + { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" }, + { LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" }, + { LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" }, + { LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" }, + { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, + { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -229,6 +243,35 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_LLAMA4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, 
"blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_MLLAMA, { @@ -594,6 +637,45 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_QWEN3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN3MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_PHI2, { @@ -811,9 +893,12 @@ static const std::map> LLM_TENSOR_N { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, @@ -1073,6 +1158,22 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, + { + LLM_ARCH_PLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_CHATGLM, { @@ -1091,6 +1192,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_GLM4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, 
"blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_BITNET, { @@ -1275,6 +1395,74 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_RWKV7, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" }, + { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" }, + { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" }, + { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" }, + { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" }, + { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" }, + { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" }, + { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" }, + { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" }, + { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" }, + { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" }, + { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" }, + { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, + }, + }, + { + LLM_ARCH_ARWKV7, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" }, + { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" }, + { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" }, + { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" }, + { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" }, + { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" }, + { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" }, + { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" }, + { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" }, + { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" }, + { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { 
LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_GRANITE, { @@ -1372,6 +1560,29 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, + { + LLM_ARCH_BAILINGMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_MISTRAL3, { @@ -1468,6 +1679,12 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -1486,6 +1703,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, @@ -1493,6 +1713,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git 
a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index 8476ae0a1..ee081fbfe 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -10,6 +10,7 @@ enum llm_arch { LLM_ARCH_LLAMA, + LLM_ARCH_LLAMA4, LLM_ARCH_MLLAMA, LLM_ARCH_DECI, LLM_ARCH_FALCON, @@ -30,6 +31,8 @@ enum llm_arch { LLM_ARCH_QWEN2, LLM_ARCH_QWEN2MOE, LLM_ARCH_QWEN2VL, + LLM_ARCH_QWEN3, + LLM_ARCH_QWEN3MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, @@ -56,6 +59,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, + LLM_ARCH_GLM4, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, @@ -64,12 +68,16 @@ enum llm_arch { LLM_ARCH_EXAONE, LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, + LLM_ARCH_RWKV7, + LLM_ARCH_ARWKV7, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, LLM_ARCH_SOLAR, LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_MISTRAL3, + LLM_ARCH_PLM, + LLM_ARCH_BAILINGMOE, LLM_ARCH_UNKNOWN, }; @@ -78,6 +86,7 @@ enum llm_kv { LLM_KV_GENERAL_ARCHITECTURE, LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_FILE_TYPE, LLM_KV_GENERAL_NAME, LLM_KV_GENERAL_AUTHOR, LLM_KV_GENERAL_VERSION, @@ -116,6 +125,7 @@ enum llm_kv { LLM_KV_RESIDUAL_SCALE, LLM_KV_EMBEDDING_SCALE, LLM_KV_TOKEN_SHIFT_COUNT, + LLM_KV_INTERLEAVE_MOE_LAYER_STEP, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -130,6 +140,10 @@ enum llm_kv { LLM_KV_ATTENTION_CAUSAL, LLM_KV_ATTENTION_Q_LORA_RANK, LLM_KV_ATTENTION_KV_LORA_RANK, + LLM_KV_ATTENTION_DECAY_LORA_RANK, + LLM_KV_ATTENTION_ICLR_LORA_RANK, + LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, + LLM_KV_ATTENTION_GATE_LORA_RANK, LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SCALE, @@ -248,6 +262,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_POST_ATTN_NORM, + LLM_TENSOR_POST_MLP_NORM, LLM_TENSOR_SSM_IN, LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, @@ -255,8 +271,20 @@ enum llm_tensor { LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, + LLM_TENSOR_TIME_MIX_A0, + LLM_TENSOR_TIME_MIX_A1, + LLM_TENSOR_TIME_MIX_A2, + LLM_TENSOR_TIME_MIX_V0, + LLM_TENSOR_TIME_MIX_V1, + LLM_TENSOR_TIME_MIX_V2, + LLM_TENSOR_TIME_MIX_G1, + LLM_TENSOR_TIME_MIX_G2, + LLM_TENSOR_TIME_MIX_K_K, + LLM_TENSOR_TIME_MIX_K_A, + LLM_TENSOR_TIME_MIX_R_K, LLM_TENSOR_TIME_MIX_LERP_X, LLM_TENSOR_TIME_MIX_LERP_W, LLM_TENSOR_TIME_MIX_LERP_K, diff --git a/llama/llama.cpp/src/llama-batch.h b/llama/llama.cpp/src/llama-batch.h index 773c3808b..f1df40d27 100644 --- a/llama/llama.cpp/src/llama-batch.h +++ b/llama/llama.cpp/src/llama-batch.h @@ -42,9 +42,9 @@ struct llama_sbatch { bool logits_all; // TODO: remove once lctx.logits_all is removed too // sorted indices into the batch - std::vector ids; + std::vector ids; // batch indices of the output - std::vector out_ids; + std::vector out_ids; std::vector seq; const llama_batch * batch = nullptr; diff --git a/llama/llama.cpp/src/llama-chat.cpp b/llama/llama.cpp/src/llama-chat.cpp index 028a64794..721faa4e8 100644 --- a/llama/llama.cpp/src/llama-chat.cpp +++ b/llama/llama.cpp/src/llama-chat.cpp @@ -4,6 +4,7 @@ #include #include +#include #if __cplusplus >= 202000L #define LU8(x) (const char*)(u8##x) @@ -58,6 +59,9 @@ static const std::map LLM_CHAT_TEMPLATES = { { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, + { "yandex", LLM_CHAT_TEMPLATE_YANDEX 
}, + { "bailing", LLM_CHAT_TEMPLATE_BAILING }, + { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -167,6 +171,12 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GIGACHAT; } else if (tmpl_contains("<|role_start|>")) { return LLM_CHAT_TEMPLATE_MEGREZ; + } else if (tmpl_contains(" Ассистент:")) { + return LLM_CHAT_TEMPLATE_YANDEX; + } else if (tmpl_contains("ASSISTANT") && tmpl_contains("'HUMAN'")) { + return LLM_CHAT_TEMPLATE_BAILING; + } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { + return LLM_CHAT_TEMPLATE_LLAMA4; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -566,7 +576,51 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|role_start|>assistant<|role_end|>"; } - } else { + } else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) { + // Yandex template ("\n\n" is defined as EOT token) + + ss << ""; + + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (role == "user") { + ss << " Пользователь: " << chat[i]->content << "\n\n"; + } else if (role == "assistant") { + ss << " Ассистент: " << chat[i]->content << "\n\n"; + } + } + + // Add generation prompt if needed + if (add_ass) { + ss << " Ассистент:[SEP]"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) { + // Bailing (Ling) template + for (auto message : chat) { + std::string role(message->role); + + if (role == "user") { + role = "HUMAN"; + } else { + std::transform(role.begin(), role.end(), role.begin(), ::toupper); + } + + ss << "" << role << "" << message->content; + } + + if (add_ass) { + ss << "ASSISTANT"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) { + // Llama 4 + for (auto message : chat) { + std::string role(message->role); + ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>"; + } + if (add_ass) { + ss << "<|header_start|>assistant<|header_end|>\n\n"; + } + } else { // template not supported return -1; } @@ -584,4 +638,3 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) { } return (int32_t) LLM_CHAT_TEMPLATES.size(); } - diff --git a/llama/llama.cpp/src/llama-chat.h b/llama/llama.cpp/src/llama-chat.h index 2f6a0e3e2..34537ca21 100644 --- a/llama/llama.cpp/src/llama-chat.h +++ b/llama/llama.cpp/src/llama-chat.h @@ -38,6 +38,9 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, + LLM_CHAT_TEMPLATE_YANDEX, + LLM_CHAT_TEMPLATE_BAILING, + LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index 7b22fe137..d6e7b3af8 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ b/llama/llama.cpp/src/llama-context.cpp @@ -1,559 +1,1525 @@ #include "llama-context.h" #include "llama-impl.h" +#include "llama-io.h" #include "llama-mmap.h" +#include "llama-model.h" +#include "llama-kv-cache.h" #include -#include #include #include +#include -void llama_set_k_shift(struct llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; +// +// llama_context +// - assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); +llama_context::llama_context( + const llama_model & model, + llama_context_params params) : + model(model) { + LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); - int32_t * data = (int32_t *) lctx.inp_K_shift->data; + t_start_us = model.t_start_us; + t_load_us = model.t_load_us; - for (int i = 0; i < 
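
// Illustrative sketch (assumed caller code, not part of the patch): the new "llama4"
// built-in template registered above can be exercised through the public
// llama_chat_apply_template() helper; it renders each turn as
// <|header_start|>role<|header_end|>\n\n<content><|eot|> per the branch added in
// llm_chat_apply_template().
#include "llama.h"
#include <string>
#include <vector>

static std::string render_llama4(const std::vector<llama_chat_message> & msgs) {
    std::vector<char> buf(4096);
    int32_t n = llama_chat_apply_template("llama4", msgs.data(), msgs.size(),
                                          /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    if (n < 0) {
        return ""; // unknown or unsupported template
    }
    if ((size_t) n > buf.size()) {
        buf.resize(n); // returned value is the required length
        n = llama_chat_apply_template("llama4", msgs.data(), msgs.size(), true, buf.data(), (int32_t) buf.size());
    }
    return std::string(buf.data(), (size_t) n);
}
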
kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].delta; - } -} + const auto & hparams = model.hparams; -void llama_set_s_copy(struct llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; + cparams.n_seq_max = std::max(1u, params.n_seq_max); + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.yarn_ext_factor = params.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow; + cparams.defrag_thold = params.defrag_thold; + cparams.embeddings = params.embeddings; + cparams.offload_kqv = params.offload_kqv; + cparams.flash_attn = params.flash_attn; + cparams.no_perf = params.no_perf; + cparams.pooling_type = params.pooling_type; + cparams.warmup = false; - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; + cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; + cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; - int32_t * data = (int32_t *) lctx.inp_s_copy->data; + cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : + hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn : + hparams.n_ctx_train; - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; - } -} + cparams.cb_eval = params.cb_eval; + cparams.cb_eval_user_data = params.cb_eval_user_data; -// llama input - -static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { - // TODO move to hparams if a T5 variant appears that uses a different value - const int64_t max_distance = 128; - - if (bidirectional) { - n_buckets >>= 1; + auto rope_scaling_type = params.rope_scaling_type; + if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { + rope_scaling_type = hparams.rope_scaling_type_train; } - const int64_t max_exact = n_buckets >> 1; + if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) { + cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none + } - int32_t relative_position = x - y; - int32_t relative_bucket = 0; - if (bidirectional) { - relative_bucket += (relative_position > 0) * n_buckets; - relative_position = abs(relative_position); + if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set' + cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; + } + + cparams.yarn_attn_factor *= hparams.rope_attn_factor; + + if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { + if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { + cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; + } else { + cparams.pooling_type = hparams.pooling_type; + } + } + + if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) { + cparams.causal_attn = hparams.causal_attn; } else { - relative_position = -std::min(relative_position, 0); - } - int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); - relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); - relative_bucket += (relative_position < max_exact ? 
relative_position : relative_position_if_large); - return relative_bucket; -} - -void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { - // - // set input data - // - - const auto & hparams = lctx.model.hparams; - const auto & cparams = lctx.cparams; - const auto & kv_self = lctx.kv_self; - - if (ubatch.token) { - const int64_t n_tokens = ubatch.n_tokens; - - ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens)); + cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; } - if (ubatch.embd) { - if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) { - ggml_backend_tensor_set(lctx.inp_cross_attn_state, ubatch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state)); - // zero out inp_embd since it's not used - float * inp_embd_data = (float *)lctx.inp_embd->data; - for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) { - inp_embd_data[i] = 0.0f; + // with causal attention, the batch size is limited by the context size + cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; + + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + // ref: https://github.com/ggerganov/llama.cpp/pull/5021 + // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self + if (cparams.n_batch < GGML_KQ_MASK_PAD) { + LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); + cparams.n_batch = GGML_KQ_MASK_PAD; + } + + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + + LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq); + LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); + LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); + LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); + LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + + if (n_ctx_per_seq < hparams.n_ctx_train) { + LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", + __func__, n_ctx_per_seq, hparams.n_ctx_train); + } + + if (n_ctx_per_seq > hparams.n_ctx_train) { + LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", + __func__, n_ctx_per_seq, hparams.n_ctx_train); + } + + logits_all = params.logits_all; + + if (!hparams.vocab_only) { + // GPU backends + for (auto * dev : model.devices) { + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (backend == nullptr) { + throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); } - } else { - const int64_t n_embd = hparams.n_embd; - const int64_t n_tokens = ubatch.n_tokens; + backends.emplace_back(backend); + } - ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); + // add ACCEL backends (such as BLAS) + for 
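
// Worked example of the parameter resolution above (illustrative numbers, assuming a model
// trained with n_ctx_train = 8192 and a caller passing n_seq_max = 4):
//   params.n_ctx    = 0,    n_ctx_train = 8192     -> cparams.n_ctx    = 8192
//   causal_attn     = true, params.n_batch = 16384 -> cparams.n_batch  = min(8192, 16384) = 8192
//   params.n_ubatch = 512                          -> cparams.n_ubatch = min(8192, 512)   = 512
//   n_ctx_per_seq   = 8192 / 4 = 2048, which is < n_ctx_train, so the "full capacity of the
//   model will not be utilized" warning above fires.
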
(size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (backend == nullptr) { + throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); + } + backends.emplace_back(backend); + } + } + + // add CPU backend + backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (backend_cpu == nullptr) { + throw std::runtime_error("failed to initialize CPU backend"); + } + backends.emplace_back(backend_cpu); + + // create a list of the set_n_threads functions in the backends + for (auto & backend : backends) { + ggml_backend_dev_t dev = ggml_backend_get_device(backend.get()); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + if (reg) { + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn); + } + } + } + + llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data); + + // graph outputs buffer + { + // resized during inference when a batch uses more outputs + if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) { + throw std::runtime_error("failed to reserve initial output buffer"); + } + + LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__, + ggml_backend_buffer_name (buf_output.get()), + ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0); } } - if (ubatch.pos && lctx.inp_pos) { - const int64_t n_tokens = ubatch.n_tokens; - auto n_pos = lctx.n_pos_per_token; - ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos)); + // init the memory module + // TODO: for now, always create a unified KV cache + if (!hparams.vocab_only) { + kv_self.reset(static_cast(model.create_memory())); + + LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams)); + + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + uint32_t kv_size = cparams.n_ctx; + ggml_type type_k = params.type_k; + ggml_type type_v = params.type_v; + + if (llama_model_is_recurrent(&model)) { + // Mamba needs at least as many KV cells as there are sequences kept at any time + kv_size = std::max((uint32_t) 1, params.n_seq_max); + // it's probably best to keep as much precision as possible for the states + type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states + type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states + } + + GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); + GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); + + if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { + throw std::runtime_error("failed to initialize self-attention cache"); + } + + { + const size_t memory_size_k = kv_self->size_k_bytes(); + const size_t memory_size_v = kv_self->size_v_bytes(); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 
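
// Rough worked example of the "KV self size" log above (illustrative numbers only, and
// assuming the usual unified-cache layout of kv_size cells per layer): with kv_size = 8192,
// 32 layers, n_embd_k_gqa = n_embd_v_gqa = 1024 and f16 K/V types,
//   K: 8192 * 32 * 1024 * 2 bytes = 512.00 MiB,  V: likewise 512.00 MiB,
// so the line would read "KV self size = 1024.00 MiB, K (f16): 512.00 MiB, V (f16): 512.00 MiB".
// For recurrent (Mamba-style) models the branch above instead sets kv_size = n_seq_max and
// forces f32 states, so the cache stays small and independent of the context length.
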
1024.0f)); + } } - if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - //GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); + // init backends + if (!hparams.vocab_only) { + LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__); - if (!lctx.inp_out_ids) { - LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__); + backend_buft.clear(); + backend_ptrs.clear(); + + for (auto & backend : backends) { + auto * buft = ggml_backend_get_default_buffer_type(backend.get()); + auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); + + if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { + // use the host buffer of the first device CPU for faster transfer of the intermediate state + auto * dev = model.devices[0]; + auto * host_buft = ggml_backend_dev_host_buffer_type(dev); + if (host_buft) { + buft = host_buft; + } + } + + backend_buft.push_back(buft); + backend_ptrs.push_back(backend.get()); + } + + LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); + + const size_t max_nodes = this->graph_max_nodes(); + + LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); + + // buffer used to store the computation graph and the tensor meta data + buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); + + // TODO: move these checks to ggml_backend_sched + // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary + bool pipeline_parallel = + model.n_devices() > 1 && + model.params.n_gpu_layers > (int) model.hparams.n_layer && + model.params.split_mode == LLAMA_SPLIT_MODE_LAYER && + cparams.offload_kqv && + !model.has_tensor_overrides(); + + // pipeline parallelism requires support for async compute and events in all devices + if (pipeline_parallel) { + for (auto & backend : backends) { + auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); + if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { + // ignore CPU backend + continue; + } + auto * dev = ggml_backend_get_device(backend.get()); + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + if (!props.caps.async || !props.caps.events) { + // device does not support async compute or events + pipeline_parallel = false; + break; + } + } + } + + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel)); + + if (pipeline_parallel) { + LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); + } + } + + // reserve worst-case graph + if (!hparams.vocab_only) { + const uint32_t n_seqs = 1; // TODO: worst-case number of sequences + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + + // restore later + // TODO: something cleaner + const auto n_outputs_save = n_outputs; + + LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + + int n_splits_pp = -1; + int n_nodes_pp = -1; + + int n_splits_tg = -1; + int n_nodes_tg = -1; + + // simulate full KV cache + kv_self->n = kv_self->size; + + cross.v_embd.clear(); + + // reserve pp graph first so that buffers are only allocated once + { + llama_ubatch 
ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + // max number of outputs + n_outputs = ubatch_pp.n_tokens; + + LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); + + auto * gf = graph_init(); + graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); + + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + + n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_pp = ggml_graph_n_nodes(gf); + } + + // reserve with tg graph to get the number of splits and nodes + { + llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + n_outputs = ubatch_tg.n_tokens; + + LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs); + + auto * gf = graph_init(); + graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT); + + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + throw std::runtime_error("failed to allocate compute tg buffers"); + } + + n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_tg = ggml_graph_n_nodes(gf); + } + + // reserve again with pp graph to avoid ggml-alloc reallocations during inference + { + llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + n_outputs = ubatch_pp.n_tokens; + + LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); + + auto * gf = graph_init(); + graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); + + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + } + + n_outputs = n_outputs_save; + + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; + size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend); + if (size > 1) { + LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, + ggml_backend_buft_name(buft), + size / 1024.0 / 1024.0); + } + } + + if (n_nodes_pp == n_nodes_tg) { + LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); } else { - const int64_t n_tokens = ubatch.n_tokens; + LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + } - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer)); - int32_t * data = (int32_t *) lctx.inp_out_ids->data; + if (n_splits_pp == n_splits_tg) { + LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); + } else { + LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); + } + } +} - if (lctx.n_outputs == n_tokens) { - for (int i = 0; i < n_tokens; ++i) { - data[i] = i; +llama_context::~llama_context() = default; + +void llama_context::synchronize() { + ggml_backend_sched_synchronize(sched.get()); + + // FIXME: if multiple single tokens are evaluated without a synchronization, + // the stats will be added to the prompt evaluation stats + // this should only happen when using batch size 1 to evaluate a batch + + // add the evaluation to the stats + if (n_queued_tokens == 1) { + if (!cparams.no_perf) { + t_eval_us += ggml_time_us() - 
t_compute_start_us; + } + n_eval++; + } else if (n_queued_tokens > 1) { + if (!cparams.no_perf) { + t_p_eval_us += ggml_time_us() - t_compute_start_us; + } + n_p_eval += n_queued_tokens; + } + + // get a more accurate load time, upon first eval + if (n_queued_tokens > 0 && !has_evaluated_once) { + t_load_us = ggml_time_us() - t_start_us; + has_evaluated_once = true; + } + + n_queued_tokens = 0; + t_compute_start_us = 0; +} + +const llama_model & llama_context::get_model() const { + return model; +} + +uint32_t llama_context::n_ctx() const { + return cparams.n_ctx; +} + +uint32_t llama_context::n_ctx_per_seq() const { + return cparams.n_ctx / cparams.n_seq_max; +} + +uint32_t llama_context::n_batch() const { + return cparams.n_batch; +} + +uint32_t llama_context::n_ubatch() const { + return cparams.n_ubatch; +} + +uint32_t llama_context::n_seq_max() const { + return cparams.n_seq_max; +} + +uint32_t llama_context::n_threads() const { + return cparams.n_threads; +} + +uint32_t llama_context::n_threads_batch() const { + return cparams.n_threads_batch; +} + +llama_kv_cache * llama_context::get_kv_self() { + return kv_self.get(); +} + +const llama_kv_cache * llama_context::get_kv_self() const { + return kv_self.get(); +} + +ggml_tensor * llama_context::build_rope_shift( + ggml_context * ctx0, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale, + ggml_backend_buffer * bbuf) const { + const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; + + const auto & yarn_ext_factor = cparams.yarn_ext_factor; + const auto & yarn_attn_factor = cparams.yarn_attn_factor; + const auto & yarn_beta_fast = cparams.yarn_beta_fast; + const auto & yarn_beta_slow = cparams.yarn_beta_slow; + + const auto & hparams = model.hparams; + + const auto & n_rot = hparams.n_rot; + const auto & rope_type = hparams.rope_type; + + ggml_tensor * tmp; + + if (ggml_is_quantized(cur->type)) { + // dequantize to f32 -> RoPE -> quantize back + tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32); + + if (bbuf) { + for (const auto & backend : backends) { + // Figure out which backend KV cache belongs to + if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) { + ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get()); + break; } - } else if (ubatch.output) { - int32_t n_outputs = 0; - for (int i = 0; i < n_tokens; ++i) { - if (ubatch.output[i]) { - data[n_outputs++] = i; - } + } + } + + tmp = ggml_rope_ext_inplace(ctx0, tmp, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + + tmp = ggml_cpy(ctx0, tmp, cur); + } else { + // we rotate only the first n_rot dimensions + tmp = ggml_rope_ext_inplace(ctx0, cur, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + } + + return tmp; +} + +class llm_graph_input_k_shift : public llm_graph_input_i { +public: + llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_k_shift() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * k_shift; // I32 [kv_size] + + const llama_kv_cache_unified * kv_self; +}; + +void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (k_shift) { + assert(ggml_backend_buffer_is_host(k_shift->buffer)); + + int32_t * data = (int32_t *) k_shift->data; + + for (uint32_t i = 0; i < 
kv_self->size; ++i) { + data[i] = kv_self->cells[i].delta; + } + } +} + +llm_graph_result_ptr llama_context::build_kv_self_shift( + ggml_context * ctx0, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & hparams = model.hparams; + + const auto & n_layer = hparams.n_layer; + + const auto & n_embd_head_k = hparams.n_embd_head_k; + //const auto & n_embd_head_v = hparams.n_embd_head_v; + + //GGML_ASSERT(kv_self->size == n_ctx); + + auto inp = std::make_unique(kv_self.get()); + + inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx); + ggml_set_input(inp->k_shift); + + for (uint32_t il = 0; il < n_layer; ++il) { + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + + const bool is_swa = hparams.is_swa(il); + + // note: the swa rope params could become part of the cparams in the future + // if we decide to make them configurable, like the non-sliding ones + const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; + const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; + + ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il); + + ggml_tensor * k = + ggml_view_3d(ctx0, kv_self->k_l[il], + n_embd_head_k, n_head_kv, kv_self->size, + ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k), + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), + 0); + + ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer); + + ggml_build_forward_expand(gf, cur); + } + + res->add_input(std::move(inp)); + + return res; +} + +llm_graph_result_ptr llama_context::build_kv_self_defrag( + ggml_context * ctx0, + ggml_cgraph * gf, + const std::vector & moves) const { + auto res = std::make_unique(); + + const auto & hparams = model.hparams; + +#if 0 + // CPU defrag + // + // TODO: optimizations are possible: + // - multiple threads + // - avoid copying to the host memory when already there + // + // likely not worth the effort, as we have ggml_graph based defrag + // + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + const uint32_t kv_size = size; + + std::vector buf_k; + std::vector buf_v; + + for (uint32_t il = 0; il < n_layer; ++il) { + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); + + const size_t v_size_el = ggml_type_size(v_l[il]->type); + const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); + + buf_k.resize(k_size); + buf_v.resize(v_size); + + ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); + + // batch move [i, i+nm) to [id, id+nm) + // note: cells can move only to a lower index + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == n_kv) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < n_kv && ids[i + nm] == id + nm) { + nm++; + } + + // move keys + { + const int64_t os = i*k_size_row; + const int64_t od = id*k_size_row; + + memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); + } + + // move values (note: they are transposed) + { + const int64_t os = i; + const int64_t od = id; + + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, 
nm*v_size_el); } - // the graph needs to have been passed the correct number of outputs - GGML_ASSERT(lctx.n_outputs == n_outputs); - } else if (lctx.n_outputs == 1) { - // only keep last output - data[0] = n_tokens - 1; + } + + i += nm - 1; + } + + ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); + } +#else + for (const auto & move : moves) { + for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il], + n_embd_k_gqa, move.len, + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*move.src)); + + ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il], + n_embd_k_gqa, move.len, + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*move.dst)); + + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], + n_embd_v_gqa, move.len, + ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*move.src)); + + view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], + n_embd_v_gqa, move.len, + ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*move.dst)); } else { - GGML_ASSERT(lctx.n_outputs == 0); + view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], + move.len, n_embd_v_gqa, + ggml_row_size(kv_self->v_l[il]->type, kv_self->size), + ggml_row_size(kv_self->v_l[il]->type, move.src)); + + view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], + move.len, n_embd_v_gqa, + ggml_row_size(kv_self->v_l[il]->type, kv_self->size), + ggml_row_size(kv_self->v_l[il]->type, move.dst)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); + } + } +#endif + + return res; +} + +void llama_context::kv_self_update() { + auto & kv = kv_self; + + if (kv->has_shift) { + if (!kv->get_can_shift()) { + GGML_ABORT("The current context does not support K-shift"); + } + + LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); + + // apply K-shift if needed + if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { + ggml_backend_sched_reset(sched.get()); + + auto * gf = graph_init(); + + auto res = build_kv_self_shift(ctx_compute.get(), gf); + + ggml_backend_sched_alloc_graph(sched.get(), gf); + + res->set_inputs(nullptr); + + graph_compute(gf, false); + } + + { + kv->has_shift = false; + + for (uint32_t i = 0; i < kv->size; ++i) { + kv->cells[i].delta = 0; } } } - GGML_ASSERT( - // (!a || b) is a logical implication (a -> b) - // !hparams.causal_attn -> !cparams.causal_attn - (hparams.causal_attn || !cparams.causal_attn) && - "causal attention is not supported by this model" - ); + // defragment the KV cache if needed + if (kv->do_defrag) { + LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + const uint32_t n_max_nodes = graph_max_nodes(); + const uint32_t max_moves = (n_max_nodes - 2*model.hparams.n_layer)/(6*model.hparams.n_layer); + if (!kv->defrag_prepare(n_max_nodes)) { + LLAMA_LOG_ERROR("%s: failed to prepare defragmentation\n", __func__); + return; + } - if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) { - 
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. - if (cparams.causal_attn && !lctx.is_encoding) { - const int64_t n_kv = kv_self.n; - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; + for (std::size_t i = 0; i < kv_self->defrag_info.moves.size(); i += max_moves) { + std::vector chunk; + auto end = std::min(i + max_moves, kv_self->defrag_info.moves.size()); + chunk.assign(kv_self->defrag_info.moves.begin() + i, kv_self->defrag_info.moves.begin() + end); + ggml_backend_sched_reset(sched.get()); + auto * gf = graph_init(); + auto res = build_kv_self_defrag(ctx_compute.get(), gf, chunk); + ggml_backend_sched_alloc_graph(sched.get(), gf); + res->set_inputs(nullptr); + graph_compute(gf, false); + } - float * data = nullptr; - float * data_swa = nullptr; + kv->do_defrag = false; + } +} - if (lctx.inp_KQ_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); - data = (float *) lctx.inp_KQ_mask->data; - } - - if (lctx.inp_KQ_mask_swa) { - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer)); - data_swa = (float *) lctx.inp_KQ_mask_swa->data; - } - - // For causal attention, use only the previous KV cells - // of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - for (int h = 0; h < 1; ++h) { - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const llama_pos pos = ubatch.pos[s*n_seq_tokens + j]; - - for (int i = 0; i < n_kv; ++i) { - float f; - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -std::abs(kv_self.cells[i].pos - pos); - } else { - f = 0.0f; - } - } - - if (data) { - data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - - // may need to cut off old tokens for sliding window - if (data_swa) { - if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { - f = -INFINITY; - } - data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - } - } - } - - if (data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } - - if (data_swa) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } +enum llama_pooling_type llama_context::pooling_type() const { + return cparams.pooling_type; +} + +float * llama_context::get_logits() { + // reorder logits for backward compatibility + output_reorder(); + + return logits; +} + +float * llama_context::get_logits_ith(int32_t i) { + int32_t j = -1; + + try { + if (logits == nullptr) { + throw std::runtime_error("no logits"); + } + + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); } else { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - // when using kv cache, the mask needs to match the kv cache size - const int64_t n_stride = hparams.causal_attn && 
!lctx.is_encoding ? kv_self.n : n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); - - float * data = (float *) lctx.inp_KQ_mask->data; - - for (int h = 0; h < 1; ++h) { - for (int s1 = 0; s1 < n_seqs; ++s1) { - const llama_seq_id seq_id = ubatch.seq_id[s1][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const int32_t tj = s1*n_seq_tokens + j; - - for (int s0 = 0; s0 < n_seqs; ++s0) { - for (int i = 0; i < n_seq_tokens; ++i) { - const int32_t ti = s0*n_seq_tokens + i; - float f = -INFINITY; - - for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { - if (ubatch.seq_id[s0][s] == seq_id) { - if (hparams.use_alibi) { - f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); - } else { - f = 0.0f; - } - break; - } - } - - data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; - } - } - - for (int i = n_tokens; i < n_stride; ++i) { - data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; - } - } - } - } + j = output_ids[i]; } + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if (j >= n_outputs) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); + } + + return logits + j*model.hparams.n_vocab; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif } +} - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; +float * llama_context::get_embeddings() { + // reorder embeddings for backward compatibility + output_reorder(); - GGML_ASSERT(lctx.inp_mean); - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); + return embd; +} - float * data = (float *) lctx.inp_mean->data; - memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); +float * llama_context::get_embeddings_ith(int32_t i) { + int32_t j = -1; - std::vector sum(n_tokens, 0); - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); - - sum[seq_id] += ubatch.n_seq_tokens; + try { + if (embd == nullptr) { + throw std::runtime_error("no embeddings"); } - std::vector div(n_tokens, 0.0f); - for (int i = 0; i < n_tokens; ++i) { - const uint64_t s = sum[i]; - if (s > 0) { - div[i] = 1.0f/float(s); - } - } - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - for (int i = 0; i < n_seq_tokens; ++i) { - data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; - } - } - } - - if (cparams.embeddings && ( - cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || - cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - GGML_ASSERT(lctx.inp_cls); - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); - - uint32_t * data = (uint32_t *) lctx.inp_cls->data; - memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true - 
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); - - for (int i = 0; i < n_seq_tokens; ++i) { - const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; - - if (pos == 0) { - data[seq_id] = s*n_seq_tokens + i; - } - } - } - } - - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - GGML_ASSERT(lctx.inp_cls); - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); - - uint32_t * data = (uint32_t *) lctx.inp_cls->data; - memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); - - std::vector last_pos(n_tokens, -1); - std::vector last_row(n_tokens, -1); - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); - - for (int i = 0; i < n_seq_tokens; ++i) { - const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; - - if (pos >= last_pos[seq_id]) { - last_pos[seq_id] = pos; - last_row[seq_id] = s*n_seq_tokens + i; - } - } - } - - for (int i = 0; i < n_tokens; ++i) { - if (last_row[i] >= 0) { - data[i] = last_row[i]; - } - } - } - - if (kv_self.recurrent) { - const int64_t n_kv = kv_self.n; - - if (lctx.inp_s_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); - float * data = (float *) lctx.inp_s_mask->data; - - // clear unused states - for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - - data[i] = (float) (kv_cell.src >= 0); - - // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } - } - } - - if (lctx.inp_s_copy) { - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { - kv_cell.src = cell_id; - } - - data[i] = kv_cell.src; - - // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; - } - } - } - } - - if (lctx.inp_pos_bucket) { - const int64_t n_tokens = ubatch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); - GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing - - int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; - - if (!lctx.is_encoding) { - const int64_t n_kv = kv_self.n; - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_kv; ++i) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); - } - } + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); } else { - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_tokens; 
++i) { - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); - } - } - } + j = output_ids[i]; } + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if (j >= n_outputs) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); + } + + return embd + j*model.hparams.n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + +float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { + auto it = embd_seq.find(seq_id); + if (it == embd_seq.end()) { + return nullptr; } - if (!lctx.is_encoding && lctx.inp_embd_enc) { - assert(lctx.inp_embd_enc->type == GGML_TYPE_F32); - assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size()); + return it->second.data(); +} - ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc)); - } +void llama_context::attach_threadpool( + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch) { + LLAMA_LOG_DEBUG("%s: call\n", __func__); - if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) { - const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd; - const int64_t n_tokens = ubatch.n_tokens; + this->threadpool = threadpool; + this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; +} - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer)); - GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing +void llama_context::detach_threadpool() { + LLAMA_LOG_DEBUG("%s: call\n", __func__); - float * data = (float *) lctx.inp_KQ_mask_cross->data; + this->threadpool = nullptr; + this->threadpool_batch = nullptr; +} - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_output_enc; ++i) { - float f = -INFINITY; - for (int s = 0; s < ubatch.n_seq_id[j]; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[j][s]; - if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) { - f = 0.0f; - } - } - data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f; - } - } +void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { + LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch); - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_output_enc; ++j) { - data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY; - } - } + cparams.n_threads = n_threads; + cparams.n_threads_batch = n_threads_batch; +} + +void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { + LLAMA_LOG_DEBUG("%s: call\n", __func__); + + this->abort_callback = abort_callback; + this->abort_callback_data = abort_callback_data; + + for (auto & backend : backends) { + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); } } } -// llama output +void llama_context::set_embeddings(bool value) { + LLAMA_LOG_DEBUG("%s: 
value = %d\n", __func__, value); -size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { - const auto & cparams = lctx.cparams; - const auto & hparams = lctx.model.hparams; + cparams.embeddings = value; +} - const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max); +void llama_context::set_causal_attn(bool value) { + LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + + cparams.causal_attn = value; +} + +void llama_context::set_warmup(bool value) { + LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + + cparams.warmup = value; +} + +void llama_context::set_cross_attn(bool value) { + cparams.cross_attn = value; +} + +void llama_context::set_adapter_lora( + llama_adapter_lora * adapter, + float scale) { + LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); + + loras[adapter] = scale; +} + +bool llama_context::rm_adapter_lora( + llama_adapter_lora * adapter) { + LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); + + auto pos = loras.find(adapter); + if (pos != loras.end()) { + loras.erase(pos); + return true; + } + + return false; +} + +void llama_context::clear_adapter_lora() { + LLAMA_LOG_DEBUG("%s: call\n", __func__); + + loras.clear(); +} + +bool llama_context::apply_adapter_cvec( + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); + + return cvec.apply(model, data, len, n_embd, il_start, il_end); +} + +int llama_context::encode(llama_batch & inp_batch) { + if (inp_batch.n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); + return -1; + } + + // temporary allocate memory for the input batch if needed + // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? 
-1 : kv_self->pos_max() + 1); + + const llama_batch & batch = batch_allocr.batch; + const int32_t n_tokens = batch.n_tokens; + + const auto & hparams = model.hparams; + + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + + if (batch.token) { + for (int32_t i = 0; i < n_tokens; ++i) { + if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + return -1; + } + } + } + + // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot + GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens"); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + + n_queued_tokens += n_tokens; + + const int64_t n_embd = hparams.n_embd; + + sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true); + + const llama_ubatch ubatch = sbatch.split_simple(n_tokens); + + // reserve output buffer + if (output_reserve(n_tokens) < n_tokens) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); + return -2; + }; + + for (int32_t i = 0; i < n_tokens; ++i) { + output_ids[i] = i; + } + + n_outputs = n_tokens; + + //batch_manager->prepare(ubatch); + + ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + const auto causal_attn_org = cparams.causal_attn; + + // always use non-causal attention for encoder graphs + // TODO: this is a tmp solution until we have a proper way to support enc-dec models + // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 + cparams.causal_attn = false; + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER); + + ggml_backend_sched_alloc_graph(sched.get(), gf); + + res->set_inputs(&ubatch); + + cparams.causal_attn = causal_attn_org; + + const auto compute_status = graph_compute(gf, n_tokens > 1); + switch (compute_status) { + case GGML_STATUS_SUCCESS: + break; + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } + + auto * t_embd = res->get_embd_pooled() ? 
res->get_embd_pooled() : res->get_embd(); + + // extract embeddings + if (t_embd) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + GGML_ASSERT(embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings + auto & embd_seq_out = embd_seq; + embd_seq_out.clear(); + + GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits + + for (int32_t i = 0; i < n_tokens; i++) { + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // TODO: this likely should be the same logic as in llama_decoder_internal, but better to + // wait for an encoder model that requires this pooling type in order to test it + // https://github.com/ggerganov/llama.cpp/pull/9510 + GGML_ABORT("RANK pooling not implemented yet"); + } + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } + } + } + + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + + // TODO: hacky solution + if (model.arch == LLM_ARCH_T5 && t_embd) { + //cross.t_embd = t_embd; + + synchronize(); + + cross.n_embd = t_embd->ne[0]; + cross.n_enc = t_embd->ne[1]; + cross.v_embd.resize(cross.n_embd*cross.n_enc); + memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); + + // remember the sequence ids used during the encoding - needed for cross attention later + cross.seq_ids_enc.resize(n_tokens); + for (int32_t i = 0; i < n_tokens; i++) { + cross.seq_ids_enc[i].clear(); + for (int s = 0; s < ubatch.n_seq_id[i]; s++) { + llama_seq_id seq_id = ubatch.seq_id[i][s]; + cross.seq_ids_enc[i].insert(seq_id); + } + } + } + + return 0; +} + +int llama_context::decode(llama_batch & inp_batch) { + if (inp_batch.n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); + return -1; + } + + // temporary allocate memory for the input batch if needed + // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? 
-1 : kv_self->pos_max() + 1); + + const llama_batch & batch = batch_allocr.batch; + + const auto & hparams = model.hparams; + + const int32_t n_vocab = hparams.n_vocab; + + const int64_t n_tokens_all = batch.n_tokens; + const int64_t n_embd = hparams.n_embd; + + llama_kv_cache_guard kv_guard(kv_self.get()); + + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + + if (batch.token) { + for (int64_t i = 0; i < n_tokens_all; ++i) { + if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]); + throw std::runtime_error("invalid token"); + } + } + } + + GGML_ASSERT(n_tokens_all <= cparams.n_batch); + + GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + n_queued_tokens += n_tokens_all; + + // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens + const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + + embd_seq.clear(); + + int64_t n_outputs_all = 0; + + // count outputs + if (batch.logits) { + for (uint32_t i = 0; i < n_tokens_all; ++i) { + n_outputs_all += batch.logits[i] != 0; + } + } else if (logits_all || embd_pooled) { + n_outputs_all = n_tokens_all; + } else { + // keep last output only + n_outputs_all = 1; + } + + const bool logits_all = n_outputs_all == n_tokens_all; + + sbatch.from_batch(batch, batch.n_embd, + /* simple_split */ !kv_self->recurrent, + /* logits_all */ logits_all); + + // reserve output buffer + if (output_reserve(n_outputs_all) < n_outputs_all) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all); + return -2; + }; + + // handle any pending defrags/shifts + kv_self_update(); + + int64_t n_outputs_prev = 0; + + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch = llama_ubatch(); + + const auto & n_ubatch = cparams.n_ubatch; + + if (kv_self->recurrent) { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = sbatch.split_seq(cparams.n_ubatch); + } else { + // recurrent model architectures are easier to implement + // with equal-length sequences + ubatch = sbatch.split_equal(cparams.n_ubatch); + } + } else { + ubatch = sbatch.split_simple(n_ubatch); + } + + // count the outputs in this u_batch + { + int32_t n_outputs_new = 0; + + if (n_outputs_all == n_tokens_all) { + n_outputs_new = ubatch.n_tokens; + } else { + GGML_ASSERT(ubatch.output); + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + n_outputs_new += (int32_t) (ubatch.output[i] != 0); + } + } + + // needs to happen before the graph is built + n_outputs = n_outputs_new; + } + + // find KV slot + { + if (!kv_self->find_slot(ubatch)) { + kv_self->defrag(); + kv_self_update(); + if (!kv_self->find_slot(ubatch)) { + LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); + return 1; + } + } + + if (!kv_self->recurrent) { + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + const uint32_t pad = kv_self->get_padding(cparams); + kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), 
pad))); + } + } + + //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head); + + ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER); + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + ggml_backend_sched_alloc_graph(sched.get(), gf); + + res->set_inputs(&ubatch); + + const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); + if (compute_status != GGML_STATUS_SUCCESS) { + switch (compute_status) { + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } + } + + // plot the computation graph in dot format (for debugging purposes) + //if (n_past%100 == 0) { + // ggml_graph_dump_dot(gf, NULL, "llama.dot"); + //} + + auto * t_logits = cparams.causal_attn ? res->get_logits() : nullptr; + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + + if (t_embd && res->get_embd_pooled()) { + t_embd = res->get_embd_pooled(); + } + + // extract logits + if (t_logits && n_outputs > 0) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + float * logits_out = logits + n_outputs_prev*n_vocab; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } + } + + // extract embeddings + if (t_embd && n_outputs > 0) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd + n_outputs_prev*n_embd; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - a single float per sequence + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(1); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); + } + } break; + case 
LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } + } + } + + n_outputs_prev += n_outputs; + } + + // finalize the batch processing + kv_guard.commit(); + + // set output mappings + { + bool sorted_output = true; + + GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all); + + for (int64_t i = 0; i < n_outputs_all; ++i) { + int64_t out_id = sbatch.out_ids[i]; + output_ids[out_id] = i; + if (out_id != i) { + sorted_output = false; + } + } + + if (sorted_output) { + sbatch.out_ids.clear(); + } + } + + // set to total number of outputs in the batch, for use in llama_get_logits_ith + n_outputs = n_outputs_all; + + // wait for the computation to finish (automatically done when obtaining the model output) + //synchronize(); + + // decide if we need to defrag the kv cache + if (cparams.causal_attn && cparams.defrag_thold > 0.0f) { + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f; + + // queue defragmentation for next llama_kv_cache_update + if (fragmentation > cparams.defrag_thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + kv_self->defrag(); + } + } + + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + + return 0; +} + +// +// output +// + +int32_t llama_context::output_reserve(int32_t n_outputs) { + const auto & hparams = model.hparams; + + const int64_t n_outputs_max = std::max(n_outputs, n_seq_max()); const auto n_batch = cparams.n_batch; const auto n_vocab = hparams.n_vocab; const auto n_embd = hparams.n_embd; // TODO: use a per-batch flag for logits presence instead - const bool has_logits = cparams.causal_attn; - const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); + bool has_logits = cparams.causal_attn; + bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); - const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; - const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; - - if (lctx.output_ids.empty()) { - // init, never resized afterwards - lctx.output_ids.resize(n_batch); + // TODO: hacky enc-dec support + if (model.arch == LLM_ARCH_T5) { + has_logits = true; + has_embd = true; } - const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0; + logits_size = has_logits ? n_vocab*n_outputs_max : 0; + embd_size = has_embd ? n_embd*n_outputs_max : 0; + + if (output_ids.empty()) { + // init, never resized afterwards + output_ids.resize(n_batch); + } + + const size_t prev_size = buf_output ? 
ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = (logits_size + embd_size) * sizeof(float); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer - if (!lctx.buf_output || prev_size < new_size) { - if (lctx.buf_output) { + if (!buf_output || prev_size < new_size) { + if (buf_output) { #ifndef NDEBUG // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif - lctx.buf_output = nullptr; - lctx.logits = nullptr; - lctx.embd = nullptr; + buf_output = nullptr; + logits = nullptr; + embd = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory - auto * output_dev = lctx.model.dev_output(); + auto * output_dev = model.dev_output(); auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; if (output_dev_host_buft) { buft = output_dev_host_buft; } - lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); - if (lctx.buf_output == nullptr) { + buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); + if (buf_output == nullptr) { LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); return 0; } } - float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get()); + float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - lctx.logits = has_logits ? output_base : nullptr; - lctx.embd = has_embd ? output_base + logits_size : nullptr; - - lctx.output_size = n_outputs_max; - lctx.logits_size = logits_size; - lctx.embd_size = embd_size; + logits = has_logits ? output_base : nullptr; + embd = has_embd ? output_base + logits_size : nullptr; // set all ids as invalid (negative) - std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1); + std::fill(output_ids.begin(), output_ids.end(), -1); - ggml_backend_buffer_clear(lctx.buf_output.get(), 0); + ggml_backend_buffer_clear(buf_output.get(), 0); - lctx.n_outputs = 0; + this->n_outputs = 0; + this->n_outputs_max = n_outputs_max; return n_outputs_max; } -void llama_output_reorder(struct llama_context & ctx) { - std::vector & out_ids = ctx.sbatch.out_ids; +void llama_context::output_reorder() { + auto & out_ids = sbatch.out_ids; if (!out_ids.empty()) { - const uint32_t n_vocab = ctx.model.hparams.n_vocab; - const uint32_t n_embd = ctx.model.hparams.n_embd; + const uint32_t n_vocab = model.hparams.n_vocab; + const uint32_t n_embd = model.hparams.n_embd; - const int32_t n_outputs = ctx.n_outputs; GGML_ASSERT((size_t) n_outputs == out_ids.size()); // TODO: is there something more efficient which also minimizes swaps? 
@@ -567,846 +1533,156 @@ void llama_output_reorder(struct llama_context & ctx) { } if (j_min == i) { continue; } std::swap(out_ids[i], out_ids[j_min]); - if (ctx.logits_size > 0) { + if (logits_size > 0) { for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]); + std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); } } - if (ctx.embd_size > 0) { + if (embd_size > 0) { for (uint32_t k = 0; k < n_embd; k++) { - std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]); + std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); } } } - std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1); + std::fill(output_ids.begin(), output_ids.end(), -1); for (int32_t i = 0; i < n_outputs; ++i) { - ctx.output_ids[out_ids[i]] = i; + output_ids[out_ids[i]] = i; } out_ids.clear(); } } // -// interface implementation +// graph // -void llama_free(struct llama_context * ctx) { - delete ctx; +int32_t llama_context::graph_max_nodes() const { + return std::max(65536, 5*model.n_tensors()); } -uint32_t llama_n_ctx(const struct llama_context * ctx) { - return ctx->cparams.n_ctx; +ggml_cgraph * llama_context::graph_init() { + ggml_init_params params = { + /*.mem_size =*/ buf_compute_meta.size(), + /*.mem_buffer =*/ buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ctx_compute.reset(ggml_init(params)); + + return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } -uint32_t llama_n_batch(const struct llama_context * ctx) { - return ctx->cparams.n_batch; +llm_graph_result_ptr llama_context::graph_build( + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype) { + return model.build_graph( + { + /*.ctx =*/ ctx, + /*.arch =*/ model.arch, + /*.hparams =*/ model.hparams, + /*.cparams =*/ cparams, + /*.ubatch =*/ ubatch, + /*.sched =*/ sched.get(), + /*.backend_cpu =*/ backend_cpu, + /*.cvec =*/ &cvec, + /*.loras =*/ &loras, + /*.memory =*/ kv_self.get(), + /*.cross =*/ &cross, + /*.n_outputs =*/ n_outputs, + /*.cb =*/ graph_get_cb(), + }, gf, gtype); } -uint32_t llama_n_ubatch(const struct llama_context * ctx) { - return ctx->cparams.n_ubatch; -} +ggml_status llama_context::graph_compute( + ggml_cgraph * gf, + bool batched) { + int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; + ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; -uint32_t llama_n_seq_max(const struct llama_context * ctx) { - return ctx->kv_self.size; -} - -const struct llama_model * llama_get_model(const struct llama_context * ctx) { - return &ctx->model; -} - -enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { - return ctx->cparams.pooling_type; -} - -void llama_attach_threadpool( - struct llama_context * ctx, - ggml_threadpool_t threadpool, - ggml_threadpool_t threadpool_batch) { - ctx->threadpool = threadpool; - ctx->threadpool_batch = threadpool_batch ? 
threadpool_batch : threadpool; -} - -void llama_detach_threadpool(struct llama_context * ctx) { - ctx->threadpool = nullptr; - ctx->threadpool_batch = nullptr; -} - -void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) { - ctx->cparams.n_threads = n_threads; - ctx->cparams.n_threads_batch = n_threads_batch; -} - -int32_t llama_n_threads(struct llama_context * ctx) { - return ctx->cparams.n_threads; -} - -int32_t llama_n_threads_batch(struct llama_context * ctx) { - return ctx->cparams.n_threads_batch; -} - -void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { - ctx->abort_callback = abort_callback; - ctx->abort_callback_data = abort_callback_data; - - for (auto & backend : ctx->backends) { - auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); - auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); - if (set_abort_callback_fn) { - set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data); - } - } -} - -void llama_set_embeddings(struct llama_context * ctx, bool embeddings) { - ctx->cparams.embeddings = embeddings; -} - -void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { - ctx->cparams.causal_attn = causal_attn; -} - -void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) { - ctx->cparams.cross_attn = cross_attention; -} - -void llama_synchronize(struct llama_context * ctx) { - ggml_backend_sched_synchronize(ctx->sched.get()); - - // FIXME: if multiple single tokens are evaluated without a synchronization, - // the stats will be added to the prompt evaluation stats - // this should only happen when using batch size 1 to evaluate a batch - - // add the evaluation to the stats - if (ctx->n_queued_tokens == 1) { - if (!ctx->cparams.no_perf) { - ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us; - } - ctx->n_eval++; - } else if (ctx->n_queued_tokens > 1) { - if (!ctx->cparams.no_perf) { - ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us; - } - ctx->n_p_eval += ctx->n_queued_tokens; + if (backend_cpu != nullptr) { + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); + auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); + set_threadpool_fn(backend_cpu, tp); } - // get a more accurate load time, upon first eval - if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) { - ctx->t_load_us = ggml_time_us() - ctx->t_start_us; - ctx->has_evaluated_once = true; + // set the number of threads for all the backends + for (const auto & set_n_threads_fn : set_n_threads_fns) { + set_n_threads_fn.second(set_n_threads_fn.first, n_threads); } - ctx->n_queued_tokens = 0; - ctx->t_compute_start_us = 0; + auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); + } + + // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched)); + + return status; } -float * llama_get_logits(struct llama_context * ctx) { - llama_synchronize(ctx); - - // reorder logits for backward compatibility - // TODO: maybe deprecate this - llama_output_reorder(*ctx); - - return ctx->logits; -} - -float * 
llama_get_logits_ith(struct llama_context * ctx, int32_t i) { - int32_t j = -1; - - llama_synchronize(ctx); - - try { - if (ctx->logits == nullptr) { - throw std::runtime_error("no logits"); - } - - if (i < 0) { - j = ctx->n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); - } - } else if ((size_t) i >= ctx->output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size())); +llm_graph_cb llama_context::graph_get_cb() const { + return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { + if (il >= 0) { + ggml_format_name(cur, "%s-%d", name, il); } else { - j = ctx->output_ids[i]; + ggml_set_name(cur, name); } - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= ctx->n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); - } - - return ctx->logits + j*ctx->model.hparams.n_vocab; - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); -#ifndef NDEBUG - GGML_ABORT("fatal error"); -#else - return nullptr; -#endif - } -} - -float * llama_get_embeddings(struct llama_context * ctx) { - llama_synchronize(ctx); - - // reorder embeddings for backward compatibility - // TODO: maybe deprecate this - llama_output_reorder(*ctx); - - return ctx->embd; -} - -float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { - int32_t j = -1; - - llama_synchronize(ctx); - - try { - if (ctx->embd == nullptr) { - throw std::runtime_error("no embeddings"); - } - - if (i < 0) { - j = ctx->n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); - } - } else if ((size_t) i >= ctx->output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size())); - } else { - j = ctx->output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= ctx->n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); - } - - return ctx->embd + j*ctx->model.hparams.n_embd; - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); -#ifndef NDEBUG - GGML_ABORT("fatal error"); -#else - return nullptr; -#endif - } -} - -float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) { - llama_synchronize(ctx); - - auto it = ctx->embd_seq.find(seq_id); - if (it == ctx->embd_seq.end()) { - return nullptr; - } - - return it->second.data(); -} - -// llama state API - -// deprecated -size_t llama_get_state_size(struct llama_context * ctx) { - return llama_state_get_size(ctx); -} - -// deprecated -size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { - return llama_state_get_data(ctx, dst, -1); -} - -// deprecated -size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { - return llama_state_set_data(ctx, src, -1); -} - -// deprecated -bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { - return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); -} - -// deprecated -bool 
llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { - return llama_state_save_file(ctx, path_session, tokens, n_token_count); -} - -// TODO: replace all non-fatal assertions with returned errors or exceptions -struct llama_data_write { - virtual void write(const void * src, size_t size) = 0; - virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0; - virtual size_t get_size_written() = 0; - virtual ~llama_data_write() = default; - - void write_string(const std::string & str) { - uint32_t str_size = str.size(); - - write(&str_size, sizeof(str_size)); - write(str.data(), str_size); - } - - void write_model_info(const struct llama_context * ctx) { - const std::string arch_str = llm_arch_name(ctx->model.arch); - write_string(arch_str); - // TODO: add more model-specific info which should prevent loading the session file if not identical - } - - //void write_rng(const std::mt19937 & rng) { - // std::ostringstream rng_ss; - // rng_ss << rng; - - // const std::string & rng_str = rng_ss.str(); - - // write_string(rng_str); - //} - - void write_output_ids(struct llama_context * ctx) { - llama_output_reorder(*ctx); - - const uint32_t n_outputs = ctx->n_outputs; - - std::vector output_pos; - - const size_t n_batch = ctx->cparams.n_batch; - const auto & output_ids = ctx->output_ids; - - GGML_ASSERT(n_outputs <= ctx->output_size); - - output_pos.resize(n_outputs); - - // build a more compact representation of the output ids - for (size_t i = 0; i < n_batch; ++i) { - // map an output id to a position in the batch - int32_t pos = output_ids[i]; - if (pos >= 0) { - GGML_ASSERT((uint32_t) pos < n_outputs); - output_pos[pos] = i; + if (!cparams.offload_kqv) { + if (strcmp(name, "kqv_merged_cont") == 0) { + // all nodes between the KV store and the attention output are run on the CPU + ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); } } - write(&n_outputs, sizeof(n_outputs)); - - if (n_outputs) { - write(output_pos.data(), n_outputs * sizeof(int32_t)); - } - } - - void write_logits(const struct llama_context * ctx) { - const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab); - - write(&logits_size, sizeof(logits_size)); - - if (logits_size) { - write(ctx->logits, logits_size * sizeof(float)); - } - } - - void write_embeddings(const struct llama_context * ctx) { - const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd); - - write(&embeddings_size, sizeof(embeddings_size)); - - if (embeddings_size) { - write(ctx->embd, embeddings_size * sizeof(float)); - } - } - - void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { - for (const auto & range : cell_ranges) { - for (uint32_t i = range.first; i < range.second; ++i) { - const auto & cell = kv_self.cells[i]; - const llama_pos pos = cell.pos; - const uint32_t n_seq_id = seq_id == -1 ? 
cell.seq_id.size() : 0; - - write(&pos, sizeof(pos)); - write(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id) { - for (auto seq_id : cell.seq_id) { - write(&seq_id, sizeof(seq_id)); - } - } - } - } - } - - void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { - const struct llama_kv_cache & kv_self = ctx->kv_self; - const struct llama_hparams & hparams = ctx->model.hparams; - - const uint32_t v_trans = kv_self.v_trans ? 1 : 0; - const uint32_t n_layer = hparams.n_layer; - - write(&v_trans, sizeof(v_trans)); - write(&n_layer, sizeof(n_layer)); - - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell - // Get whole range at a time - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Write key type - const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; - write(&k_type_i, sizeof(k_type_i)); - - // Write row size of key - const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); - write(&k_size_row, sizeof(k_size_row)); - - // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * k_size_row; - write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size); - } - } - - if (!kv_self.v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - write(&v_type_i, sizeof(v_type_i)); - - // Write row size of value - const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); - write(&v_size_row, sizeof(v_size_row)); - - // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * v_size_row; - write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size); - } - } - } else { - // When v is transposed, we also need the element size and get the element ranges from each row - const uint32_t kv_size = kv_self.size; - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - write(&v_type_i, sizeof(v_type_i)); - - // Write element size - const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - write(&v_size_el, sizeof(v_size_el)); - - // Write GQA embedding size - write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); - - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - const size_t buf_size = range_size * v_size_el; - write_tensor_data(kv_self.v_l[il], src_offset, buf_size); - } - } - } - } - } - - void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { - const struct llama_kv_cache & kv_self = ctx->kv_self; - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; - - // Count the number of cells with the specified 
seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = kv_self.size; - for (uint32_t i = 0; i < kv_self.size; ++i) { - const auto & cell = kv_self.cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { - ++cell_count; - if (cell_range_begin == kv_self.size) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != kv_self.size) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = kv_self.size; - } - } - } - if (cell_range_begin != kv_self.size) { - cell_ranges.emplace_back(cell_range_begin, kv_self.size); - } - - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == cell_count_check); - - write(&cell_count, sizeof(cell_count)); - - write_kv_cache_meta(kv_self, cell_ranges, seq_id); - write_kv_cache_data(ctx, cell_ranges); - } -}; - -struct llama_data_read { - virtual const uint8_t * read(size_t size) = 0; - virtual void read_to(void * dst, size_t size) = 0; - virtual size_t get_size_read() = 0; - virtual ~llama_data_read() = default; - - void read_string(std::string & str) { - uint32_t str_size; - read_to(&str_size, sizeof(str_size)); - - str.assign((const char *) read(str_size), str_size); - } - - // validate model information - void read_model_info(const struct llama_context * ctx) { - const std::string cur_arch_str = llm_arch_name(ctx->model.arch); - - std::string arch_str; - read_string(arch_str); - if (cur_arch_str != arch_str) { - throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); - } - // TODO: add more info which needs to be identical but which is not verified otherwise - } - - //void read_rng(std::mt19937 & rng) { - // std::string rng_str; - // read_string(rng_str); - - // std::istringstream rng_ss(rng_str); - // rng_ss >> rng; - - // if (rng_ss.fail()) { - // throw std::runtime_error("failed to load RNG state"); - // } - //} - - void read_output_ids(struct llama_context * ctx) { - std::vector output_pos; - - uint32_t n_outputs; - read_to(&n_outputs, sizeof(n_outputs)); - - if (n_outputs > llama_output_reserve(*ctx, n_outputs)) { - throw std::runtime_error("could not reserve outputs"); - } - - if (n_outputs) { - output_pos.resize(n_outputs); - read_to(output_pos.data(), n_outputs * sizeof(int32_t)); - - for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { - int32_t id = output_pos[i]; - if ((uint32_t) id >= ctx->cparams.n_batch) { - throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch)); - } - ctx->output_ids[id] = i; - } - - ctx->n_outputs = n_outputs; - } - } - - void read_logits(struct llama_context * ctx) { - uint64_t logits_size; - read_to(&logits_size, sizeof(logits_size)); - - if (ctx->logits_size < logits_size) { - throw std::runtime_error("logits buffer too small"); - } - - if (logits_size) { - read_to(ctx->logits, logits_size * sizeof(float)); - } - } - - void read_embeddings(struct llama_context * ctx) { - uint64_t embeddings_size; - read_to(&embeddings_size, sizeof(embeddings_size)); - - if (ctx->embd_size < embeddings_size) { - throw std::runtime_error("embeddings buffer too small"); - } - - if (embeddings_size) { - read_to(ctx->embd, embeddings_size * sizeof(float)); - } - } - - bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, 
llama_seq_id dest_seq_id = -1) { - struct llama_kv_cache & kv_self = ctx->kv_self; - - if (dest_seq_id != -1) { - // single sequence - - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - - llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - batch.n_tokens = cell_count; - batch.n_seq_tokens = cell_count; - batch.n_seqs = 1; - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; - - read_to(&pos, sizeof(pos)); - read_to(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id != 0) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); - return false; - } - - batch.pos[i] = pos; - } - batch.n_seq_id[0] = 1; - batch.seq_id[0] = &dest_seq_id; - if (!llama_kv_cache_find_slot(kv_self, batch)) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; - } - - // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); - GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); - GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); - GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); - } else { - // whole KV cache restore - - if (cell_count > kv_self.size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } - - llama_kv_cache_clear(kv_self); - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_kv_cell & cell = kv_self.cells[i]; - - llama_pos pos; - uint32_t n_seq_id; - - read_to(&pos, sizeof(pos)); - read_to(&n_seq_id, sizeof(n_seq_id)); - - cell.pos = pos; - - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - read_to(&seq_id, sizeof(seq_id)); - - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - return false; - } - - cell.seq_id.insert(seq_id); - - if (kv_self.recurrent) { - int32_t & tail = kv_self.cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; + // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends + // FIXME: fix in ggml_backend_sched + const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer; + if (ubatch.n_tokens < 32 || full_offload) { + if (il != -1 && strcmp(name, "norm") == 0) { + const auto & dev_layer = model.dev_layer(il); + for (const auto & backend : backends) { + if (ggml_backend_get_device(backend.get()) == dev_layer) { + if (ggml_backend_supports_op(backend.get(), cur)) { + ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); } - tail = i; - } - } - } - - kv_self.head = 0; - kv_self.used = cell_count; - } - - if (kv_self.recurrent) { - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = kv_self.head + i; - // make sure the recurrent states will keep their restored state - kv_self.cells[cell_id].src = cell_id; - } - } - - return true; - } - - bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) { - const struct llama_hparams & hparams = ctx->model.hparams; - struct llama_kv_cache & kv_self = ctx->kv_self; - 
uint32_t v_trans; - uint32_t n_layer; - read_to(&v_trans, sizeof(v_trans)); - read_to(&n_layer, sizeof(n_layer)); - - if (n_layer != hparams.n_layer) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); - return false; - } - if (cell_count > kv_self.size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size); - return false; - } - if (kv_self.v_trans != (bool) v_trans) { - LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); - return false; - } - - // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Read type of key - int32_t k_type_i_ref; - read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; - if (k_type_i != k_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); - return false; - } - - // Read row size of key - uint64_t k_size_row_ref; - read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); - if (k_size_row != k_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row); - } - } - - if (!kv_self.v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read row size of value - uint64_t v_size_row_ref; - read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); - if (v_size_row != v_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row); - } - } - } else { - // For each layer, read the values for each cell (transposed) - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read element size of value - uint32_t v_size_el_ref; - read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - if (v_size_el != v_size_el_ref) { - LLAMA_LOG_ERROR("%s: mismatched value element 
size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); - return false; - } - - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); - return false; - } - - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el; - ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } } } } - return true; - } + }; +} - void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) { - uint32_t cell_count; - read_to(&cell_count, sizeof(cell_count)); +// +// state save/load +// - bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count); - - if (!res) { - if (seq_id == -1) { - llama_kv_cache_clear(ctx); - } else { - llama_kv_cache_seq_rm(ctx, seq_id, -1, -1); - } - throw std::runtime_error("failed to restore kv cache"); - } - } -}; - -struct llama_data_write_dummy : llama_data_write { - size_t size_written = 0; - - llama_data_write_dummy() {} +class llama_io_write_dummy : public llama_io_write_i { +public: + llama_io_write_dummy() = default; void write(const void * /* src */, size_t size) override { size_written += size; } - void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { size_written += size; } - size_t get_size_written() override { + size_t n_bytes() override { return size_written; } + +private: + size_t size_written = 0; }; -struct llama_data_write_buffer : llama_data_write { - uint8_t * ptr; - size_t buf_size = 0; - size_t size_written = 0; - - llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {} +class llama_io_write_buffer : public llama_io_write_i { +public: + llama_io_write_buffer( + uint8_t * p, size_t len) : ptr(p), buf_size(len) {} void write(const void * src, size_t size) override { if (size > buf_size) { @@ -1418,7 +1694,7 @@ struct llama_data_write_buffer : llama_data_write { buf_size -= size; } - void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { + void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); } @@ -1428,17 +1704,19 @@ struct llama_data_write_buffer : llama_data_write { buf_size -= size; } - size_t get_size_written() override { + size_t n_bytes() override { return size_written; } + +private: + uint8_t * ptr; + size_t buf_size = 0; + size_t size_written = 0; }; -struct llama_data_read_buffer : llama_data_read { - const uint8_t * ptr; - size_t buf_size = 0; - size_t size_read = 0; - - llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} +class llama_io_read_buffer : public llama_io_read_i { +public: + llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} const uint8_t * read(size_t size) override { const uint8_t * base_ptr = ptr; @@ -1455,40 +1733,44 @@ struct llama_data_read_buffer : llama_data_read { memcpy(dst, read(size), size); } - size_t get_size_read() override { + size_t 
n_bytes() override { return size_read; } + +private: + const uint8_t * ptr; + size_t buf_size = 0; + size_t size_read = 0; }; -struct llama_data_write_file : llama_data_write { - llama_file * file; - size_t size_written = 0; - std::vector temp_buffer; - - llama_data_write_file(llama_file * f) : file(f) {} +class llama_io_write_file : public llama_io_write_i { +public: + llama_io_write_file(llama_file * f) : file(f) {} void write(const void * src, size_t size) override { file->write_raw(src, size); size_written += size; } - void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { + void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { temp_buffer.resize(size); ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); write(temp_buffer.data(), temp_buffer.size()); } - size_t get_size_written() override { + size_t n_bytes() override { return size_written; } + +private: + llama_file * file; + size_t size_written = 0; + std::vector temp_buffer; }; -struct llama_data_read_file : llama_data_read { - llama_file * file; - size_t size_read = 0; - std::vector temp_buffer; - - llama_data_read_file(llama_file * f) : file(f) {} +class llama_io_read_file : public llama_io_read_i { +public: + llama_io_read_file(llama_file * f) : file(f) {} void read_to(void * dst, size_t size) override { file->read_raw(dst, size); @@ -1501,89 +1783,78 @@ struct llama_data_read_file : llama_data_read { return temp_buffer.data(); } - size_t get_size_read() override { + size_t n_bytes() override { return size_read; } + +private: + llama_file * file; + size_t size_read = 0; + std::vector temp_buffer; }; -/** copy state data into either a buffer or file depending on the passed in context - * - * file context: - * llama_file file("/path", "wb"); - * llama_data_write_file data_ctx(&file); - * llama_state_get_data_internal(ctx, data_ctx); - * - * buffer context: - * std::vector buf(max_size, 0); - * llama_data_write_buffer data_ctx(buf.data(), max_size); - * llama_state_get_data_internal(ctx, data_ctx); - * -*/ -static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) { - llama_synchronize(ctx); - - data_ctx.write_model_info(ctx); - - // copy outputs - data_ctx.write_output_ids(ctx); - data_ctx.write_logits(ctx); - data_ctx.write_embeddings(ctx); - - data_ctx.write_kv_cache(ctx); - - return data_ctx.get_size_written(); -} - -size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) { - llama_data_write_buffer data_ctx(dst, size); +size_t llama_context::state_get_size() { + llama_io_write_dummy io; try { - return llama_state_get_data_internal(ctx, data_ctx); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); - return 0; - } -} - -// Returns the *actual* size of the state. -// Intended to be used when saving to state to a buffer. 
-size_t llama_state_get_size(struct llama_context * ctx) { - llama_data_write_dummy data_ctx; - try { - return llama_state_get_data_internal(ctx, data_ctx); + return state_write_data(io); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); return 0; } } -static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) { - llama_synchronize(ctx); - - data_ctx.read_model_info(ctx); - - // set outputs - data_ctx.read_output_ids(ctx); - data_ctx.read_logits(ctx); - data_ctx.read_embeddings(ctx); - - data_ctx.read_kv_cache(ctx); - - return data_ctx.get_size_read(); +size_t llama_context::state_get_data(uint8_t * dst, size_t size) { + llama_io_write_buffer io(dst, size); + try { + return state_write_data(io); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); + return 0; + } } -// Sets the state reading from the specified source address -size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) { - llama_data_read_buffer data_ctx(src, size); +size_t llama_context::state_set_data(const uint8_t * src, size_t size) { + llama_io_read_buffer io(src, size); try { - return llama_state_set_data_internal(ctx, data_ctx); + return state_read_data(io); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; } } -static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { - llama_file file(path_session, "rb"); +size_t llama_context::state_seq_get_size(llama_seq_id seq_id) { + llama_io_write_dummy io; + try { + return state_seq_write_data(io, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { + llama_io_write_buffer io(dst, size); + try { + return state_seq_write_data(io, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { + llama_io_read_buffer io(src, size); + try { + return state_seq_read_data(io, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); + return 0; + } +} + +bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(filepath, "rb"); // sanity checks { @@ -1613,28 +1884,20 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha { const size_t n_state_size_cur = file.size() - file.tell(); - llama_data_read_file data_ctx(&file); - const size_t n_read = llama_state_set_data_internal(ctx, data_ctx); + llama_io_read_file io( &file); + const size_t n_read = state_read_data(io); if (n_read != n_state_size_cur) { LLAMA_LOG_ERROR("%s: did not read all of the session file data! 
size %zu, got %zu\n", __func__, n_state_size_cur, n_read); return false; } } + return true; } -bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { - try { - return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); - return false; - } -} - -static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { - llama_file file(path_session, "wb"); +bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { + llama_file file(filepath, "wb"); file.write_u32(LLAMA_SESSION_MAGIC); file.write_u32(LLAMA_SESSION_VERSION); @@ -1644,82 +1907,13 @@ static bool llama_state_save_file_internal(struct llama_context * ctx, const cha file.write_raw(tokens, sizeof(llama_token) * n_token_count); // save the context state using stream saving - llama_data_write_file data_ctx(&file); - llama_state_get_data_internal(ctx, data_ctx); + llama_io_write_file io(&file); + state_write_data(io); return true; } -bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { - try { - return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); - return false; - } -} - -static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) { - llama_synchronize(ctx); - - data_ctx.write_kv_cache(ctx, seq_id); - - return data_ctx.get_size_written(); -} - -size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) { - llama_data_write_dummy data_ctx; - return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); -} - -size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { - llama_data_write_buffer data_ctx(dst, size); - try { - return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what()); - return 0; - } -} - -static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) { - llama_synchronize(ctx); - - data_ctx.read_kv_cache(ctx, dest_seq_id); - - return data_ctx.get_size_read(); -} - -size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) { - llama_data_read_buffer data_ctx(src, size); - try { - return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what()); - return 0; - } -} - -static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { - llama_file file(filepath, "wb"); - - file.write_u32(LLAMA_STATE_SEQ_MAGIC); - file.write_u32(LLAMA_STATE_SEQ_VERSION); - - // save the prompt - file.write_u32((uint32_t) n_token_count); - file.write_raw(tokens, sizeof(llama_token) * 
n_token_count); - - // save the context state using stream saving - llama_data_write_file data_ctx(&file); - llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); - - const size_t res = file.tell(); - GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written()); - return res; -} - -static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(filepath, "rb"); // version checks @@ -1749,8 +1943,8 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con // restore the context state { const size_t state_size = file.size() - file.tell(); - llama_data_read_file data_ctx(&file); - const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); + llama_io_read_file io(&file); + const size_t nread = state_seq_read_data(io, seq_id); if (!nread) { LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); return 0; @@ -1762,26 +1956,844 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con return file.tell(); } -size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { +size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { + llama_file file(filepath, "wb"); + + file.write_u32(LLAMA_STATE_SEQ_MAGIC); + file.write_u32(LLAMA_STATE_SEQ_VERSION); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_io_write_file io(&file); + state_seq_write_data(io, seq_id); + + const size_t res = file.tell(); + GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); + + return res; +} + +size_t llama_context::state_write_data(llama_io_write_i & io) { + LLAMA_LOG_DEBUG("%s: writing state\n", __func__); + + // write model info + { + LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__); + + const std::string arch_str = llm_arch_name(model.arch); + io.write_string(arch_str); + // TODO: add more model-specific info which should prevent loading the session file if not identical + } + + // write output ids + { + LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); + + output_reorder(); + + const auto n_outputs = this->n_outputs; + const auto & output_ids = this->output_ids; + + std::vector w_output_pos; + + GGML_ASSERT(n_outputs <= n_outputs_max); + + w_output_pos.resize(n_outputs); + + // build a more compact representation of the output ids + for (size_t i = 0; i < n_batch(); ++i) { + // map an output id to a position in the batch + int32_t pos = output_ids[i]; + if (pos >= 0) { + GGML_ASSERT(pos < n_outputs); + w_output_pos[pos] = i; + } + } + + io.write(&n_outputs, sizeof(n_outputs)); + + if (n_outputs) { + io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); + } + } + + // write logits + { + LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); + + const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.hparams.n_vocab); + + io.write(&logits_size, 
sizeof(logits_size)); + + if (logits_size) { + io.write(logits, logits_size * sizeof(float)); + } + } + + // write embeddings + { + LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); + + const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); + + io.write(&embd_size, sizeof(embd_size)); + + if (embd_size) { + io.write(embd, embd_size * sizeof(float)); + } + } + + LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + kv_self->state_write(io); + + return io.n_bytes(); +} + +size_t llama_context::state_read_data(llama_io_read_i & io) { + LLAMA_LOG_DEBUG("%s: reading state\n", __func__); + + // read model info + { + LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__); + + const std::string cur_arch_str = llm_arch_name(model.arch); + + std::string arch_str; + io.read_string(arch_str); + if (cur_arch_str != arch_str) { + throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); + } + // TODO: add more info which needs to be identical but which is not verified otherwise + } + + // read output ids + { + LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); + + auto n_outputs = this->n_outputs; + io.read_to(&n_outputs, sizeof(n_outputs)); + + if (n_outputs > output_reserve(n_outputs)) { + throw std::runtime_error("could not reserve outputs"); + } + + std::vector output_pos; + + if (n_outputs) { + output_pos.resize(n_outputs); + io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); + + for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { + int32_t id = output_pos[i]; + if ((uint32_t) id >= n_batch()) { + throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); + } + this->output_ids[id] = i; + } + + this->n_outputs = n_outputs; + } + } + + // read logits + { + LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); + + uint64_t logits_size; + io.read_to(&logits_size, sizeof(logits_size)); + + if (this->logits_size < logits_size) { + throw std::runtime_error("logits buffer too small"); + } + + if (logits_size) { + io.read_to(this->logits, logits_size * sizeof(float)); + } + } + + // read embeddings + { + LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); + + uint64_t embd_size; + io.read_to(&embd_size, sizeof(embd_size)); + + if (this->embd_size < embd_size) { + throw std::runtime_error("embeddings buffer too small"); + } + + if (embd_size) { + io.read_to(this->embd, embd_size * sizeof(float)); + } + } + + LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); + kv_self->state_read(io); + + return io.n_bytes(); +} + +size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { + GGML_UNUSED(seq_id); + + kv_self->state_write(io, seq_id); + + return io.n_bytes(); +} + +size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { + GGML_UNUSED(seq_id); + + kv_self->state_read(io, seq_id); + + return io.n_bytes(); +} + +// +// perf +// + +llama_perf_context_data llama_context::perf_get_data() const { + llama_perf_context_data data = {}; + + data.t_start_ms = 1e-3 * t_start_us; + data.t_load_ms = 1e-3 * t_load_us; + data.t_p_eval_ms = 1e-3 * t_p_eval_us; + data.t_eval_ms = 1e-3 * t_eval_us; + data.n_p_eval = std::max(1, n_p_eval); + data.n_eval = std::max(1, n_eval); + + return data; +} + +void llama_context::perf_reset() { + t_start_us = ggml_time_us(); + t_eval_us = n_eval = 0; + t_p_eval_us = n_p_eval = 0; +} + +// +// interface implementation +// + 
+llama_context_params llama_context_default_params() { + llama_context_params result = { + /*.n_ctx =*/ 512, + /*.n_batch =*/ 2048, + /*.n_ubatch =*/ 512, + /*.n_seq_max =*/ 1, + /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default + /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, + /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, + /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, + /*.rope_freq_base =*/ 0.0f, + /*.rope_freq_scale =*/ 0.0f, + /*.yarn_ext_factor =*/ -1.0f, + /*.yarn_attn_factor =*/ 1.0f, + /*.yarn_beta_fast =*/ 32.0f, + /*.yarn_beta_slow =*/ 1.0f, + /*.yarn_orig_ctx =*/ 0, + /*.defrag_thold =*/ -1.0f, + /*.cb_eval =*/ nullptr, + /*.cb_eval_user_data =*/ nullptr, + /*.type_k =*/ GGML_TYPE_F16, + /*.type_v =*/ GGML_TYPE_F16, + /*.logits_all =*/ false, + /*.embeddings =*/ false, + /*.offload_kqv =*/ true, + /*.flash_attn =*/ false, + /*.no_perf =*/ true, + /*.cross_attn =*/ false, + /*.abort_callback =*/ nullptr, + /*.abort_callback_data =*/ nullptr, + }; + + return result; +} + +llama_context * llama_init_from_model( + llama_model * model, + llama_context_params params) { + if (!model) { + LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__); + return nullptr; + } + + if (params.n_batch == 0 && params.n_ubatch == 0) { + LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); + return nullptr; + } + + if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) { + LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__); + return nullptr; + } + + if (params.flash_attn && model->arch == LLM_ARCH_GROK) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); + params.flash_attn = false; + } + + if (ggml_is_quantized(params.type_v) && !params.flash_attn) { + LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); + return nullptr; + } + try { - return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count); + auto * ctx = new llama_context(*model, params); + return ctx; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what()); + } + + return nullptr; +} + +// deprecated +llama_context * llama_new_context_with_model( + llama_model * model, + llama_context_params params) { + return llama_init_from_model(model, params); +} + +void llama_free(llama_context * ctx) { + delete ctx; +} + +uint32_t llama_n_ctx(const llama_context * ctx) { + return ctx->n_ctx(); +} + +uint32_t llama_n_batch(const llama_context * ctx) { + return ctx->n_batch(); +} + +uint32_t llama_n_ubatch(const llama_context * ctx) { + return ctx->n_ubatch(); +} + +uint32_t llama_n_seq_max(const llama_context * ctx) { + return ctx->n_seq_max(); +} + +const llama_model * llama_get_model(const llama_context * ctx) { + return &ctx->get_model(); +} + +llama_kv_cache * llama_get_kv_self(llama_context * ctx) { + return ctx->get_kv_self(); +} + +void llama_kv_self_update(llama_context * ctx) { + ctx->kv_self_update(); +} + +enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { + return ctx->pooling_type(); +} + +void llama_attach_threadpool( + llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch) { + ctx->attach_threadpool(threadpool, threadpool_batch); +} + +void llama_detach_threadpool(llama_context * ctx) { + ctx->detach_threadpool(); +} + +void llama_set_n_threads(llama_context * 
ctx, int32_t n_threads, int32_t n_threads_batch) { + ctx->set_n_threads(n_threads, n_threads_batch); +} + +int32_t llama_n_threads(llama_context * ctx) { + return ctx->n_threads(); +} + +int32_t llama_n_threads_batch(llama_context * ctx) { + return ctx->n_threads_batch(); +} + +void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { + ctx->set_abort_callback(abort_callback, abort_callback_data); +} + +void llama_set_embeddings(llama_context * ctx, bool embeddings) { + ctx->set_embeddings(embeddings); +} + +void llama_set_causal_attn(llama_context * ctx, bool causal_attn) { + ctx->set_causal_attn(causal_attn); +} + +void llama_set_warmup(llama_context * ctx, bool warmup) { + ctx->set_warmup(warmup); +} + +void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) { + ctx->set_cross_attn(cross_attention); +} + +void llama_synchronize(llama_context * ctx) { + ctx->synchronize(); +} + +float * llama_get_logits(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_logits(); +} + +float * llama_get_logits_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_logits_ith(i); +} + +float * llama_get_embeddings(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings(); +} + +float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_ith(i); +} + +float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { + ctx->synchronize(); + + return ctx->get_embeddings_seq(seq_id); +} + +// llama adapter API + +int32_t llama_set_adapter_lora( + llama_context * ctx, + llama_adapter_lora * adapter, + float scale) { + ctx->set_adapter_lora(adapter, scale); + + return 0; +} + +int32_t llama_rm_adapter_lora( + llama_context * ctx, + llama_adapter_lora * adapter) { + bool res = ctx->rm_adapter_lora(adapter); + + return res ? 0 : -1; +} + +void llama_clear_adapter_lora(llama_context * ctx) { + ctx->clear_adapter_lora(); +} + +int32_t llama_apply_adapter_cvec( + llama_context * ctx, + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); + + return res ? 
0 : -1; +} + +// +// kv cache view +// + +llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) { + const auto * kv = ctx->get_kv_self(); + if (kv == nullptr) { + LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__); + return {}; + } + + return llama_kv_cache_view_init(*kv, n_seq_max); +} + +void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) { + const auto * kv = ctx->get_kv_self(); + if (kv == nullptr) { + LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__); + return; + } + + llama_kv_cache_view_update(view, kv); +} + +// +// kv cache +// + +// deprecated +int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { + return llama_kv_self_n_tokens(ctx); +} + +int32_t llama_kv_self_n_tokens(const llama_context * ctx) { + const auto * kv = ctx->get_kv_self(); + if (!kv) { + return 0; + } + + return kv->get_n_tokens(); +} + +// deprecated +int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { + return llama_kv_self_used_cells(ctx); +} + +int32_t llama_kv_self_used_cells(const llama_context * ctx) { + const auto * kv = ctx->get_kv_self(); + if (!kv) { + return 0; + } + + return kv->get_used_cells(); +} + +// deprecated +void llama_kv_cache_clear(llama_context * ctx) { + llama_kv_self_clear(ctx); +} + +void llama_kv_self_clear(llama_context * ctx) { + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + kv->clear(); +} + +// deprecated +bool llama_kv_cache_seq_rm( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + return llama_kv_self_seq_rm(ctx, seq_id, p0, p1); +} + +bool llama_kv_self_seq_rm( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + auto * kv = ctx->get_kv_self(); + if (!kv) { + return true; + } + + return kv->seq_rm(seq_id, p0, p1); +} + +// deprecated +void llama_kv_cache_seq_cp( + llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_self_seq_cp( + llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +// deprecated +void llama_kv_cache_seq_keep( + llama_context * ctx, + llama_seq_id seq_id) { + return llama_kv_self_seq_keep(ctx, seq_id); +} + +void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + return kv->seq_keep(seq_id); +} + +// deprecated +void llama_kv_cache_seq_add( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); +} + +void llama_kv_self_seq_add( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + return kv->seq_add(seq_id, p0, p1, delta); +} + +// deprecated +void llama_kv_cache_seq_div( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); +} + +void llama_kv_self_seq_div( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + return kv->seq_div(seq_id, p0, p1, d); +} + 
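Note on the hunk above: the deprecated llama_kv_cache_* entry points are kept only as thin forwards to the new llama_kv_self_* functions, which null-check the context's KV cache before delegating. A minimal caller-side sketch of the migration (hypothetical helper, not part of this patch; it assumes only the public declarations that appear in this diff):

    #include "llama.h"

    // Clear every position of one sequence from the context's KV cache.
    // Passing p0 = -1 and p1 = -1 selects the whole position range, matching
    // the deprecated llama_kv_cache_seq_rm(ctx, seq_id, -1, -1) call it replaces.
    static bool clear_sequence(struct llama_context * ctx, llama_seq_id seq_id) {
        return llama_kv_self_seq_rm(ctx, seq_id, -1, -1);
    }

As with the wrapper shown above, the call still reports success (returns true) when the context has no KV cache at all.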
+// deprecated +llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { + return llama_kv_self_seq_pos_max(ctx, seq_id); +} + +llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { + const auto * kv = ctx->get_kv_self(); + if (!kv) { + return 0; + } + + return kv->seq_pos_max(seq_id); +} + +// deprecated +void llama_kv_cache_defrag(llama_context * ctx) { + return llama_kv_self_defrag(ctx); +} + +void llama_kv_self_defrag(llama_context * ctx) { + auto * kv = ctx->get_kv_self(); + if (!kv) { + return; + } + + return kv->defrag(); +} + +// deprecated +bool llama_kv_cache_can_shift(const llama_context * ctx) { + return llama_kv_self_can_shift(ctx); +} + +bool llama_kv_self_can_shift(const llama_context * ctx) { + const auto * kv = ctx->get_kv_self(); + if (!kv) { + return false; + } + + return kv->get_can_shift(); +} + +// deprecated +void llama_kv_cache_update(llama_context * ctx) { + llama_kv_self_update(ctx); +} + +// llama state API + +// deprecated +size_t llama_get_state_size(llama_context * ctx) { + return llama_state_get_size(ctx); +} + +// deprecated +size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) { + return llama_state_get_data(ctx, dst, -1); +} + +// deprecated +size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) { + return llama_state_set_data(ctx, src, -1); +} + +// deprecated +bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); +} + +// deprecated +bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + return llama_state_save_file(ctx, path_session, tokens, n_token_count); +} + +// Returns the *actual* size of the state. +// Intended to be used when saving to state to a buffer. 
+size_t llama_state_get_size(llama_context * ctx) { + return ctx->state_get_size(); +} + +size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) { + ctx->synchronize(); + + return ctx->state_get_data(dst, size); +} + +// Sets the state reading from the specified source address +size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) { + ctx->synchronize(); + + return ctx->state_set_data(src, size); +} + +bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + ctx->synchronize(); + + try { + return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); + return false; + } +} + +bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + ctx->synchronize(); + + try { + return ctx->state_save_file(path_session, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); + return false; + } +} + +size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { + return ctx->state_seq_get_size(seq_id); +} + +size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { + ctx->synchronize(); + + return ctx->state_seq_get_data(seq_id, dst, size); +} + +size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { + ctx->synchronize(); + + return ctx->state_seq_set_data(seq_id, src, size); +} + +size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + ctx->synchronize(); + + try { + return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what()); return 0; } } -size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + ctx->synchronize(); + try { - return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out); + return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what()); return 0; } } -const std::vector> & llama_internal_get_tensor_map( - struct llama_context * ctx -) { - return ctx->model.tensors_by_name; +/// + +int32_t llama_encode( + llama_context * ctx, + llama_batch batch) { + const int ret = ctx->encode(batch); + if (ret != 0) { + LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); + } + + return ret; +} + +int32_t llama_decode( + llama_context * ctx, + llama_batch batch) { + const int ret = ctx->decode(batch); + if (ret != 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); + } + + return ret; +} + +// +// perf +// + +llama_perf_context_data 
llama_perf_context(const llama_context * ctx) { + llama_perf_context_data data = {}; + + if (ctx == nullptr) { + return data; + } + + data = ctx->perf_get_data(); + + return data; +} + +void llama_perf_context_print(const llama_context * ctx) { + const auto data = llama_perf_context(ctx); + + const double t_end_ms = 1e-3 * ggml_time_us(); + + LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms); + LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval); + LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); + LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); +} + +void llama_perf_context_reset(llama_context * ctx) { + ctx->perf_reset(); } diff --git a/llama/llama.cpp/src/llama-context.h b/llama/llama.cpp/src/llama-context.h index cf12c9d72..a59ff8fd4 100644 --- a/llama/llama.cpp/src/llama-context.h +++ b/llama/llama.cpp/src/llama-context.h @@ -3,66 +3,216 @@ #include "llama.h" #include "llama-batch.h" #include "llama-cparams.h" -#include "llama-model.h" -#include "llama-kv-cache.h" +#include "llama-graph.h" #include "llama-adapter.h" +#include "llama-kv-cache.h" #include "ggml-cpp.h" #include -#include #include -#include + +struct llama_model; +struct llama_kv_cache; + +class llama_io_read_i; +class llama_io_write_i; struct llama_context { - llama_context(const llama_model & model) - : model(model) - , t_start_us(model.t_start_us) - , t_load_us(model.t_load_us) {} + // init scheduler and compute buffers, reserve worst-case graphs + llama_context( + const llama_model & model, + llama_context_params params); - const struct llama_model & model; + ~llama_context(); - struct llama_cparams cparams; - struct llama_sbatch sbatch; // TODO: revisit if needed - struct llama_kv_cache kv_self; - struct llama_adapter_cvec cvec; + void synchronize(); - std::unordered_map lora; + const llama_model & get_model() const; - std::vector backends; - std::vector> set_n_threads_fns; + uint32_t n_ctx() const; + uint32_t n_ctx_per_seq() const; + uint32_t n_batch() const; + uint32_t n_ubatch() const; + uint32_t n_seq_max() const; - ggml_backend_t backend_cpu = nullptr; + uint32_t n_threads() const; + uint32_t n_threads_batch() const; - ggml_threadpool_t threadpool = nullptr; - ggml_threadpool_t threadpool_batch = nullptr; + llama_kv_cache * get_kv_self(); + const llama_kv_cache * get_kv_self() const; - bool has_evaluated_once = false; + void kv_self_update(); - mutable int64_t t_start_us; - mutable int64_t t_load_us; - mutable int64_t t_p_eval_us = 0; - mutable int64_t t_eval_us = 0; + enum llama_pooling_type pooling_type() const; - mutable int64_t t_compute_start_us = 0; - mutable int64_t n_queued_tokens = 0; + float * get_logits(); + float * get_logits_ith(int32_t i); - mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - mutable int32_t n_eval = 0; // number of eval calls + float * get_embeddings(); + float * get_embeddings_ith(int32_t i); + float * get_embeddings_seq(llama_seq_id seq_id); - // host buffer for the model output (logits and embeddings) - ggml_backend_buffer_ptr buf_output; + void attach_threadpool( + ggml_threadpool_t 
threadpool, + ggml_threadpool_t threadpool_batch); + + void detach_threadpool(); + + void set_n_threads(int32_t n_threads, int32_t n_threads_batch); + + void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); + + void set_embeddings (bool value); + void set_causal_attn(bool value); + void set_warmup(bool value); + void set_cross_attn(bool value); + + void set_adapter_lora( + llama_adapter_lora * adapter, + float scale); + + bool rm_adapter_lora( + llama_adapter_lora * adapter); + + void clear_adapter_lora(); + + bool apply_adapter_cvec( + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end); + + int encode(llama_batch & inp_batch); + int decode(llama_batch & inp_batch); + + // + // state save/load + // + + size_t state_get_size(); + size_t state_get_data( uint8_t * dst, size_t size); + size_t state_set_data(const uint8_t * src, size_t size); + + size_t state_seq_get_size(llama_seq_id seq_id); + size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size); + size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size); + + bool state_load_file( + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + + bool state_save_file( + const char * filepath, + const llama_token * tokens, + size_t n_token_count); + + size_t state_seq_load_file( + llama_seq_id seq_id, + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + + size_t state_seq_save_file( + llama_seq_id seq_id, + const char * filepath, + const llama_token * tokens, + size_t n_token_count); + + // + // perf + // + + llama_perf_context_data perf_get_data() const; + void perf_reset(); + +private: + // + // output + // + + // Make sure enough space is available for outputs. + // Returns max number of outputs for which space was reserved. 
+ int32_t output_reserve(int32_t n_outputs); + + // make the outputs have the same order they had in the user-provided batch + // TODO: maybe remove this + void output_reorder(); + + // + // graph + // + + int32_t graph_max_nodes() const; + + // zero-out inputs and create the ctx_compute for the compute graph + ggml_cgraph * graph_init(); + + llm_graph_result_ptr graph_build( + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype); + + // returns the result of ggml_backend_sched_graph_compute_async execution + ggml_status graph_compute( + ggml_cgraph * gf, + bool batched); + + llm_graph_cb graph_get_cb() const; + + // used by kv_self_update() + ggml_tensor * build_rope_shift( + ggml_context * ctx0, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale, + ggml_backend_buffer * bbuf) const; + + llm_graph_result_ptr build_kv_self_shift( + ggml_context * ctx0, + ggml_cgraph * gf) const; + + llm_graph_result_ptr build_kv_self_defrag( + ggml_context * ctx0, + ggml_cgraph * gf, + const std::vector & moves) const; + + // TODO: read/write lora adapters and cvec + size_t state_write_data(llama_io_write_i & io); + size_t state_read_data (llama_io_read_i & io); + + size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id); + size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id); + + // + // members + // + + const llama_model & model; + + llama_cparams cparams; + llama_adapter_cvec cvec; + llama_adapter_loras loras; + llama_sbatch sbatch; + + llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably + + std::unique_ptr kv_self; + + // TODO: remove + bool logits_all = false; // decode output (2-dimensional array: [n_outputs][n_vocab]) size_t logits_size = 0; // capacity (of floats) for logits float * logits = nullptr; - std::vector output_ids; // map batch token positions to ids of the logits and embd buffers - size_t output_size = 0; // capacity (of tokens positions) for the output buffers - int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch - - bool logits_all = false; - // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE size_t embd_size = 0; // capacity (of floats) for embeddings @@ -72,59 +222,47 @@ struct llama_context { // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; - // whether we are computing encoder output or decoder output - bool is_encoding = false; + int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch + int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers - // TODO: find a better way to accommodate mutli-dimension position encoding methods - // number of position id each token get, 1 for each token in most cases. - // when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate. 
- int n_pos_per_token = 1; + std::vector output_ids; // map batch token positions to ids of the logits and embd buffers - // output of the encoder part of the encoder-decoder models - std::vector embd_enc; - std::vector> seq_ids_enc; - - // memory buffers used to evaluate the model - std::vector buf_compute_meta; ggml_backend_sched_ptr sched; + ggml_backend_t backend_cpu = nullptr; + std::vector backends; + + ggml_context_ptr ctx_compute; + + ggml_threadpool_t threadpool = nullptr; + ggml_threadpool_t threadpool_batch = nullptr; + ggml_abort_callback abort_callback = nullptr; void * abort_callback_data = nullptr; - // input tensors - struct ggml_tensor * inp_tokens; // I32 [n_batch] - struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] - struct ggml_tensor * inp_pos; // I32 [n_batch] - struct ggml_tensor * inp_out_ids; // I32 [n_outputs] - struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] - struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch] - struct ggml_tensor * inp_K_shift; // I32 [kv_size] - struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] - struct ggml_tensor * inp_cls; // I32 [n_batch] - struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] - struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] - struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] - struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + std::vector> set_n_threads_fns; - struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061] + // buffer types used for the compute buffer of each backend + std::vector backend_ptrs; + std::vector backend_buft; + + // memory buffers used to evaluate the model + std::vector buf_compute_meta; + + // host buffer for the model output (logits and embeddings) + ggml_backend_buffer_ptr buf_output; + + bool has_evaluated_once = false; + + // perf + mutable int64_t t_start_us = 0; + mutable int64_t t_load_us = 0; + mutable int64_t t_p_eval_us = 0; + mutable int64_t t_eval_us = 0; + + mutable int64_t t_compute_start_us = 0; + mutable int64_t n_queued_tokens = 0; + + mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + mutable int32_t n_eval = 0; // number of eval calls }; - -// TODO: make these methods of llama_context -void llama_set_k_shift(struct llama_context & lctx); - -void llama_set_s_copy(struct llama_context & lctx); - -void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch); - -// Make sure enough space is available for outputs. -// Returns max number of outputs for which space was reserved. 
-size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs); - -// make the outputs have the same order they had in the user-provided batch -void llama_output_reorder(struct llama_context & ctx); - -// For internal test use -// TODO: remove -const std::vector> & llama_internal_get_tensor_map(struct llama_context * ctx); diff --git a/llama/llama.cpp/src/llama-cparams.h b/llama/llama.cpp/src/llama-cparams.h index 9681e5a08..85ad91b9b 100644 --- a/llama/llama.cpp/src/llama-cparams.h +++ b/llama/llama.cpp/src/llama-cparams.h @@ -30,6 +30,7 @@ struct llama_cparams { bool flash_attn; bool no_perf; bool cross_attn; + bool warmup; enum llama_pooling_type pooling_type; diff --git a/llama/llama.cpp/src/llama-grammar.cpp b/llama/llama.cpp/src/llama-grammar.cpp index 98af1ba39..973b47ae0 100644 --- a/llama/llama.cpp/src/llama-grammar.cpp +++ b/llama/llama.cpp/src/llama-grammar.cpp @@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl( /* .awaiting_trigger = */ false, /* .trigger_buffer = */ "", /* .trigger_tokens = */ {}, - /* .trigger_words = */ {}, + /* .trigger_patterns = */ {}, }; } @@ -978,19 +978,15 @@ struct llama_grammar * llama_grammar_init_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_patterns, + size_t num_trigger_patterns, const llama_token * trigger_tokens, size_t num_trigger_tokens) { llama_grammar_parser parser; // if there is a grammar, parse it - if (!parser.parse(grammar_str)) { - return nullptr; - } - - // will be empty (default) if there are parse errors - if (parser.rules.empty()) { + // rules will be empty (default) if there are parse errors + if (!parser.parse(grammar_str) || parser.rules.empty()) { fprintf(stderr, "%s: failed to parse grammar\n", __func__); return nullptr; } @@ -1054,14 +1050,16 @@ struct llama_grammar * llama_grammar_init_impl( } while (true); std::vector vec_trigger_tokens; - std::vector vec_trigger_words; + std::vector vec_trigger_patterns; for (size_t i = 0; i < num_trigger_tokens; i++) { GGML_ASSERT(trigger_tokens != nullptr); vec_trigger_tokens.push_back(trigger_tokens[i]); } - for (size_t i = 0; i < num_trigger_words; i++) { - GGML_ASSERT(trigger_words != nullptr); - vec_trigger_words.push_back(trigger_words[i]); + for (size_t i = 0; i < num_trigger_patterns; i++) { + GGML_ASSERT(trigger_patterns != nullptr); + auto & trigger = vec_trigger_patterns.emplace_back(); + trigger.pattern = trigger_patterns[i]; + trigger.regex = std::regex(trigger.pattern); } // Important: vec_rules has to be moved here, not copied, because stacks contains @@ -1076,7 +1074,7 @@ struct llama_grammar * llama_grammar_init_impl( /* .awaiting_trigger = */ lazy, /* .trigger_buffer = */ "", std::move(vec_trigger_tokens), - std::move(vec_trigger_words), + std::move(vec_trigger_patterns), }; } @@ -1089,7 +1087,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { } struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { - llama_grammar * result = new llama_grammar { + auto * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, @@ -1098,7 +1096,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.awaiting_trigger, grammar.trigger_buffer, grammar.trigger_tokens, - grammar.trigger_words, + grammar.trigger_patterns, }; // redirect elements in stacks to point to new rules @@ -1173,16 +1171,18 @@ void llama_grammar_accept_impl(struct llama_grammar & 
grammar, llama_token token LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); return; } else { - // TODO: consider a smarter incremental substring search algorithm (store last position to search from). grammar.trigger_buffer += piece; - for (const auto & word : grammar.trigger_words) { - auto pos = grammar.trigger_buffer.find(word); - if (pos != std::string::npos) { + + std::smatch match; + for (const auto & trigger_pattern : grammar.trigger_patterns) { + if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { grammar.awaiting_trigger = false; - auto constrained_str = grammar.trigger_buffer.substr(pos); + // get from the first match to the end of the string + auto constrained_str = grammar.trigger_buffer.substr(match.position(1)); + // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); grammar.trigger_buffer.clear(); llama_grammar_accept_str(grammar, constrained_str); - LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str()); + LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); return; } } diff --git a/llama/llama.cpp/src/llama-grammar.h b/llama/llama.cpp/src/llama-grammar.h index b143d834c..f8c291de9 100644 --- a/llama/llama.cpp/src/llama-grammar.h +++ b/llama/llama.cpp/src/llama-grammar.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include #include #include @@ -105,6 +106,11 @@ struct llama_grammar_parser { void print(FILE * file); }; +struct llama_grammar_trigger_pattern { + std::string pattern; + std::regex regex; +}; + struct llama_grammar { // note: allow null vocab for testing (not great) const llama_vocab * vocab; @@ -122,7 +128,10 @@ struct llama_grammar { bool awaiting_trigger = false; // Initialized to true for lazy grammars only std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). - std::vector trigger_words; + std::vector + trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated + // string, and the grammar will be given the string from the first match group onwards. 
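    // A minimal standalone sketch of the trigger matching these members support, assuming only the
    // standard <regex> facilities; the helper name below is illustrative and not part of llama.cpp.
    // The pattern must match the entire buffered output, and constrained decoding begins at the
    // first capture group, mirroring the std::regex_match()/match.position(1) logic in the hunk above.
    //
    //   #include <regex>
    //   #include <string>
    //
    //   // Returns the offset at which grammar-constrained decoding should begin,
    //   // or std::string::npos when the buffered output has not triggered yet.
    //   static size_t trigger_start(const std::string & buffer, const std::regex & re) {
    //       std::smatch m;
    //       if (std::regex_match(buffer, m, re) && m.size() > 1) {
    //           return (size_t) m.position(1); // constrained text = buffer.substr(m.position(1))
    //       }
    //       return std::string::npos;
    //   }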
+ }; // @@ -141,8 +150,8 @@ struct llama_grammar * llama_grammar_init_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_patterns, + size_t num_trigger_patterns, const llama_token * trigger_tokens, size_t num_trigger_tokens); diff --git a/llama/llama.cpp/src/llama-graph.cpp b/llama/llama.cpp/src/llama-graph.cpp new file mode 100644 index 000000000..83f3c5a86 --- /dev/null +++ b/llama/llama.cpp/src/llama-graph.cpp @@ -0,0 +1,1720 @@ +#include "llama-graph.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-cparams.h" +#include "llama-kv-cache.h" + +#include +#include +#include + +static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { + // TODO move to hparams if a T5 variant appears that uses a different value + const int64_t max_distance = 128; + + if (bidirectional) { + n_buckets >>= 1; + } + + const int64_t max_exact = n_buckets >> 1; + + int32_t relative_position = x - y; + int32_t relative_bucket = 0; + + if (bidirectional) { + relative_bucket += (relative_position > 0) * n_buckets; + relative_position = abs(relative_position); + } else { + relative_position = -std::min(relative_position, 0); + } + + int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); + relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); + relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large); + + return relative_bucket; +} + +void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { + if (ubatch->token) { + const int64_t n_tokens = ubatch->n_tokens; + + ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens)); + } + + if (ubatch->embd) { + const int64_t n_embd = embd->ne[0]; + const int64_t n_tokens = ubatch->n_tokens; + + ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd)); + } +} + +void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { + if (ubatch->pos && pos) { + const int64_t n_tokens = ubatch->n_tokens; + + ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos)); + } +} + +void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { + if (ubatch->pos && attn_scale) { + const int64_t n_tokens = ubatch->n_tokens; + + std::vector attn_scale_data(n_tokens, 0.0f); + for (int i = 0; i < n_tokens; ++i) { + const float pos = ubatch->pos[i]; + attn_scale_data[i] = std::log( + std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0 + ) * f_attn_temp_scale + 1.0; + } + + ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale)); + } +} + +void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { + if (pos_bucket) { + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + + int32_t * data = (int32_t *) pos_bucket->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); + } + } + } + } +} + +void 
llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { + if (pos_bucket) { + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + + int32_t * data = (int32_t *) pos_bucket->data; + + const int64_t n_kv = kv_self->n; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false); + } + } + } + } +} + +void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) { + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + //GGML_ASSERT(out_ids && "every model that can must skip unused outputs"); + + if (!out_ids) { + LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__); + } else { + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer)); + int32_t * data = (int32_t *) out_ids->data; + + if (n_outputs == n_tokens) { + for (int i = 0; i < n_tokens; ++i) { + data[i] = i; + } + } else if (ubatch->output) { + int32_t n_outputs = 0; + for (int i = 0; i < n_tokens; ++i) { + if (ubatch->output[i]) { + data[n_outputs++] = i; + } + } + // the graph needs to have been passed the correct number of outputs + GGML_ASSERT(n_outputs == n_outputs); + } else if (n_outputs == 1) { + // only keep last output + data[0] = n_tokens - 1; + } else { + GGML_ASSERT(n_outputs == 0); + } + } + } +} + +void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(mean); + GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer)); + + float * data = (float *) mean->data; + memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean)); + + std::vector sum(n_tokens, 0); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); + + sum[seq_id] += ubatch->n_seq_tokens; + } + + std::vector div(n_tokens, 0.0f); + for (int i = 0; i < n_tokens; ++i) { + const uint64_t s = sum[i]; + if (s > 0) { + div[i] = 1.0f/float(s); + } + } + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; + } + } + } +} + +void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { + if (cparams.embeddings && ( + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(cls); + GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); + + uint32_t * data = (uint32_t *) cls->data; + memset(cls->data, 0, n_tokens * ggml_element_size(cls)); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be 
larger than n_tokens with pooling_type == CLS or RANK"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch->pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } + } + } + } + + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(cls); + GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); + + uint32_t * data = (uint32_t *) cls->data; + memset(cls->data, 0, n_tokens * ggml_element_size(cls)); + + std::vector last_pos(n_tokens, -1); + std::vector last_row(n_tokens, -1); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch->pos[s*n_seq_tokens + i]; + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = s*n_seq_tokens + i; + } + } + } + + for (int i = 0; i < n_tokens; ++i) { + if (last_row[i] >= 0) { + data[i] = last_row[i]; + } + } + } +} + +void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + const int64_t n_kv = kv_self->n; + + if (s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); + int32_t * data = (int32_t *) s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t cell_id = i + kv_self->head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; + + // prevent out-of-bound sources + if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) { + kv_cell.src = cell_id; + } + + data[i] = kv_cell.src; + + // TODO: do not mutate the KV cache + // ensure copy only happens once + if (kv_cell.src != (int32_t) cell_id) { + kv_cell.src = cell_id; + } + } + } +} + +void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + const int64_t n_kv = kv_self->n; + + if (s_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer)); + float * data = (float *) s_mask->data; + + // clear unused states + for (int i = 0; i < n_kv; ++i) { + const uint32_t cell_id = i + kv_self->head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! 
+ llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; + + data[i] = (float) (kv_cell.src >= 0); + + // only clear once + if (kv_cell.src < 0) { + kv_cell.src = cell_id; + } + } + } +} + +void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (cross_embd && !cross->v_embd.empty()) { + assert(cross_embd->type == GGML_TYPE_F32); + + ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd)); + } +} + +void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { + if (kq_mask) { + if (cparams.causal_attn) { + const int64_t n_kv = ubatch->n_tokens; + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); + float * data = (float *) kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch->seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { + if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) { + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f; + } + } + } + } + } + } else { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + const int64_t n_stride = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); + + float * data = (float *) kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch->seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { + if (ubatch->seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; + } + } + + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } + } + } + } + } + } +} + +void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { + if (self_kq_mask || self_kq_mask_swa) { + const int64_t n_kv = kv_self->n; + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + float * data = nullptr; + float * data_swa = nullptr; + + if (self_kq_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + data = (float *) self_kq_mask->data; + } + + if (self_kq_mask_swa) { + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); + data_swa = (float *) self_kq_mask_swa->data; + } + + // Use only the previous KV cells of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
+ // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: + // Causal mask: + // xxx------- + // xxxx------ + // xxxxx----- + // Non-causal mask: + // xxxxx----- + // xxxxx----- + // xxxxx----- + // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 + for (int h = 0; h < 1; ++h) { + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos pos = ubatch->pos[s*n_seq_tokens + j]; + for (int i = 0; i < n_kv; ++i) { + float f; + // mask the token if: + if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence + || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens + ) { + f = -INFINITY; + } else { + if (hparams.use_alibi) { + f = -std::abs(kv_self->cells[i].pos - pos); + } else { + f = 0.0f; + } + } + + if (data) { + data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + + // may need to cut off old tokens for sliding window + // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask" + if (data_swa) { + if (hparams.n_attn_chunk) { + llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; + if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + f = -INFINITY; + } + } else { + if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } + } + data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + } + } + } + + // mask padded tokens + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + + // mask padded tokens + if (data_swa) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + } + } +} + +void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { + if (cross_kq_mask) { + const int64_t n_enc = cross_kq_mask->ne[0]; + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + + float * data = (float *) cross_kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_enc; ++i) { + float f = -INFINITY; + for (int s = 0; s < ubatch->n_seq_id[j]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[j][s]; + if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) { + f = 0.0f; + } + } + data[h*(n_enc*n_tokens) + j*n_enc + i] = f; + } + } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_enc; ++j) { + data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY; + } + } + } + } +} + +void llm_graph_input_cross_attn_state::set_input(const llama_ubatch * ubatch) { + if (ubatch->embd) { + ggml_backend_tensor_set(cross_attn_state, ubatch->embd, 0, ggml_nbytes(cross_attn_state)); + } +} + +// +// llm_graph_context +// + +llm_graph_context::llm_graph_context(const llm_graph_params & params) : + arch (params.arch), + hparams (params.hparams), + cparams (params.cparams), + ubatch (params.ubatch), + n_embd (hparams.n_embd), + n_layer (hparams.n_layer), + n_rot (hparams.n_rot), + n_ctx (cparams.n_ctx), + n_ctx_per_seq 
(cparams.n_ctx / cparams.n_seq_max), + n_head (hparams.n_head()), + n_head_kv (hparams.n_head_kv()), + n_embd_head_k (hparams.n_embd_head_k), + n_embd_k_gqa (hparams.n_embd_k_gqa()), + n_embd_head_v (hparams.n_embd_head_v), + n_embd_v_gqa (hparams.n_embd_v_gqa()), + n_expert (hparams.n_expert), + n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used), + freq_base (cparams.rope_freq_base), + freq_scale (cparams.rope_freq_scale), + ext_factor (cparams.yarn_ext_factor), + attn_factor (cparams.yarn_attn_factor), + beta_fast (cparams.yarn_beta_fast), + beta_slow (cparams.yarn_beta_slow), + norm_eps (hparams.f_norm_eps), + norm_rms_eps (hparams.f_norm_rms_eps), + n_tokens (ubatch.n_tokens), + n_outputs (params.n_outputs), + n_ctx_orig (cparams.n_ctx_orig_yarn), + pooling_type (cparams.pooling_type), + rope_type (hparams.rope_type), + ctx0 (params.ctx), + sched (params.sched), + backend_cpu (params.backend_cpu), + cvec (params.cvec), + loras (params.loras), + memory (params.memory), + cross (params.cross), + cb_func (params.cb), + res (std::make_unique()) { + } + +int64_t llm_graph_context::n_pos_per_token() const { + return arch == LLM_ARCH_QWEN2VL ? 4 : 1; +} + +void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { + if (cb_func) { + cb_func(ubatch, cur, name, il); + } +} + +ggml_tensor * llm_graph_context::build_cvec( + ggml_tensor * cur, + int il) const { + return cvec->apply_to(ctx0, cur, il); +} + +ggml_tensor * llm_graph_context::build_lora_mm( + ggml_tensor * w, + ggml_tensor * cur) const { + ggml_tensor * res = ggml_mul_mat(ctx0, w, cur); + + for (const auto & lora : *loras) { + llama_adapter_lora_weight * lw = lora.first->get_weight(w); + if (lw == nullptr) { + continue; + } + + const float adapter_scale = lora.second; + const float scale = lw->get_scale(lora.first->alpha, adapter_scale); + + ggml_tensor * ab_cur = ggml_mul_mat( + ctx0, lw->b, + ggml_mul_mat(ctx0, lw->a, cur) + ); + + ab_cur = ggml_scale(ctx0, ab_cur, scale); + res = ggml_add(ctx0, res, ab_cur); + } + + return res; +} + +ggml_tensor * llm_graph_context::build_lora_mm_id( + ggml_tensor * w, // ggml_tensor * as + ggml_tensor * cur, // ggml_tensor * b + ggml_tensor * ids) const { + ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids); + for (const auto & lora : *loras) { + llama_adapter_lora_weight * lw = lora.first->get_weight(w); + if (lw == nullptr) { + continue; + } + + const float alpha = lora.first->alpha; + const float rank = (float) lw->b->ne[0]; + const float scale = alpha ? 
lora.second * alpha / rank : lora.second; + + ggml_tensor * ab_cur = ggml_mul_mat_id( + ctx0, lw->b, + ggml_mul_mat_id(ctx0, lw->a, cur, ids), + ids + ); + + ab_cur = ggml_scale(ctx0, ab_cur, scale); + res = ggml_add(ctx0, res, ab_cur); + } + + return res; +} + +ggml_tensor * llm_graph_context::build_norm( + ggml_tensor * cur, + ggml_tensor * mw, + ggml_tensor * mb, + llm_norm_type type, + int il) const { + switch (type) { + case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break; + case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break; + case LLM_NORM_GROUP: + { + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]); + cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]); + } break; + } + + if (mw || mb) { + cb(cur, "norm", il); + } + + if (mw) { + cur = ggml_mul(ctx0, cur, mw); + if (mb) { + cb(cur, "norm_w", il); + } + } + + if (mb) { + cur = ggml_add(ctx0, cur, mb); + } + + return cur; +} + +ggml_tensor * llm_graph_context::build_ffn( + ggml_tensor * cur, + ggml_tensor * up, + ggml_tensor * up_b, + ggml_tensor * up_s, + ggml_tensor * gate, + ggml_tensor * gate_b, + ggml_tensor * gate_s, + ggml_tensor * down, + ggml_tensor * down_b, + ggml_tensor * down_s, + ggml_tensor * act_scales, + llm_ffn_op_type type_op, + llm_ffn_gate_type type_gate, + int il) const { + ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur; + cb(tmp, "ffn_up", il); + + if (up_b) { + tmp = ggml_add(ctx0, tmp, up_b); + cb(tmp, "ffn_up_b", il); + } + + if (up_s) { + tmp = ggml_mul(ctx0, tmp, up_s); + cb(tmp, "ffn_up_s", il); + } + + if (gate) { + switch (type_gate) { + case LLM_FFN_SEQ: + { + cur = build_lora_mm(gate, tmp); + cb(cur, "ffn_gate", il); + } break; + case LLM_FFN_PAR: + { + cur = build_lora_mm(gate, cur); + cb(cur, "ffn_gate", il); + } break; + } + + if (gate_b) { + cur = ggml_add(ctx0, cur, gate_b); + cb(cur, "ffn_gate_b", il); + } + + if (gate_s) { + cur = ggml_mul(ctx0, cur, gate_s); + cb(cur, "ffn_gate_s", il); + } + + } else { + cur = tmp; + } + + switch (type_op) { + case LLM_FFN_SILU: + { + cur = ggml_silu(ctx0, cur); + cb(cur, "ffn_silu", il); + } break; + case LLM_FFN_GELU: + { + cur = ggml_gelu(ctx0, cur); + cb(cur, "ffn_gelu", il); + if (act_scales != NULL) { + cur = ggml_div(ctx0, cur, act_scales); + cb(cur, "ffn_act", il); + } + } break; + case LLM_FFN_RELU: + { + cur = ggml_relu(ctx0, cur); + cb(cur, "ffn_relu", il); + } break; + case LLM_FFN_RELU_SQR: + { + cur = ggml_relu(ctx0, cur); + cb(cur, "ffn_relu", il); + + cur = ggml_sqr(ctx0, cur); + cb(cur, "ffn_sqr(relu)", il); + } break; + case LLM_FFN_SWIGLU: + { + // Project to 4h. 
If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + int64_t split_point = cur->ne[0] / 2; + ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); + ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); + + x0 = ggml_silu(ctx0, x0); + cb(cur, "ffn_silu", il); + + cur = ggml_mul(ctx0, x0, x1); + cb(cur, "ffn_mul", il); + } break; + } + + if (type_gate == LLM_FFN_PAR) { + cur = ggml_mul(ctx0, cur, tmp); + cb(cur, "ffn_gate_par", il); + } + + if (down) { + cur = build_lora_mm(down, cur); + } + + if (down_b) { + cb(cur, "ffn_down", il); + } + + if (down_b) { + cur = ggml_add(ctx0, cur, down_b); + } + + if (down_s) { + cur = ggml_mul(ctx0, cur, down_s); + cb(cur, "ffn_down_s", il); + } + + return cur; +} + +ggml_tensor * llm_graph_context::build_moe_ffn( + ggml_tensor * cur, + ggml_tensor * gate_inp, + ggml_tensor * up_exps, + ggml_tensor * gate_exps, + ggml_tensor * down_exps, + ggml_tensor * exp_probs_b, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + bool scale_w, + float w_scale, + llama_expert_gating_func_type gating_op, + int il) const { + const int64_t n_embd = cur->ne[0]; + const int64_t n_tokens = cur->ne[1]; + const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN + + ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + ggml_tensor * probs = nullptr; + switch (gating_op) { + case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: + { + probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens] + } break; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: + { + probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] + } break; + default: + GGML_ABORT("fatal error"); + } + cb(probs, "ffn_moe_probs", il); + + // add experts selection bias - introduced in DeepSeek V3 + // leave probs unbiased as it's later used to get expert weights + ggml_tensor * selection_probs = probs; + if (exp_probs_b != nullptr) { + selection_probs = ggml_add(ctx0, probs, exp_probs_b); + cb(selection_probs, "ffn_moe_probs_biased", il); + } + + // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k + // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198 + if (arch == LLM_ARCH_LLAMA4) { + selection_probs = logits; + } + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + cb(selected_experts, "ffn_moe_topk", il); + + ggml_tensor * weights = ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights", il); + + if (norm_w) { + weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); + + ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens] + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights_norm", il); + + weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); + } + if (scale_w) { + weights = ggml_scale(ctx0, weights, w_scale); + cb(weights, "ffn_moe_weights_scaled", il); + } + + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); + + 
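+    // The graph ops above encode the usual MoE routing recipe: gate logits -> probabilities,
+    // top-k expert selection (optionally on bias-adjusted probabilities), then optional
+    // renormalization of the selected weights. A plain scalar sketch of that routing follows;
+    // the function and its names are illustrative only (softmax gating shown, sigmoid/bias
+    // variants omitted, k <= n_expert assumed) and it is not part of this patch:
+    //
+    //   #include <algorithm>
+    //   #include <cmath>
+    //   #include <numeric>
+    //   #include <utility>
+    //   #include <vector>
+    //
+    //   // Select the top `k` experts for one token and return (index, weight) pairs.
+    //   static std::vector<std::pair<int, float>> route_token(
+    //           const std::vector<float> & logits, int k, bool norm_w) {
+    //       std::vector<float> probs(logits.size());
+    //       const float max_l = *std::max_element(logits.begin(), logits.end());
+    //       float sum = 0.0f;
+    //       for (size_t i = 0; i < logits.size(); ++i) {
+    //           probs[i] = std::exp(logits[i] - max_l); // softmax gating (numerically stable)
+    //           sum += probs[i];
+    //       }
+    //       for (auto & p : probs) { p /= sum; }
+    //
+    //       std::vector<int> idx(logits.size());
+    //       std::iota(idx.begin(), idx.end(), 0);
+    //       std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
+    //           [&](int a, int b) { return probs[a] > probs[b]; }); // top-k, like ggml_top_k
+    //
+    //       std::vector<std::pair<int, float>> out;
+    //       float w_sum = 0.0f;
+    //       for (int i = 0; i < k; ++i) {
+    //           out.emplace_back(idx[i], probs[idx[i]]);
+    //           w_sum += probs[idx[i]];
+    //       }
+    //       if (norm_w) {
+    //           for (auto & e : out) { e.second /= w_sum; } // cf. ffn_moe_weights_norm
+    //       }
+    //       return out;
+    //   }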
if (weight_before_ffn) { + // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d) + ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens); + repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens] + cur = ggml_mul(ctx0, repeated, weights); + cb(cur, "ffn_moe_weighted", il); + } + + ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(gate, "ffn_moe_gate", il); + + switch (type_op) { + case LLM_FFN_SILU: + { + gate = ggml_silu(ctx0, gate); + cb(gate, "ffn_moe_silu", il); + } break; + case LLM_FFN_GELU: + { + gate = ggml_gelu(ctx0, gate); + cb(gate, "ffn_moe_gelu", il); + } break; + default: + GGML_ABORT("fatal error"); + } + + ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens] + cb(par, "ffn_moe_gate_par", il); + + ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] + cb(experts, "ffn_moe_down", il); + + if (!weight_before_ffn) { + experts = ggml_mul(ctx0, experts, weights); + cb(cur, "ffn_moe_weighted", il); + } + + // aggregate experts + ggml_tensor * moe_out = nullptr; + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens, + experts->nb[2], i*experts->nb[1]); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx0, moe_out, cur_expert); + } + } + + if (n_expert_used == 1) { + // avoid returning a non-contiguous tensor + moe_out = ggml_cont(ctx0, moe_out); + } + + cb(moe_out, "ffn_moe_out", il); + + return moe_out; +} + +// input embeddings with optional lora +ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { + const int64_t n_embd = hparams.n_embd; + + auto inp = std::make_unique(); + + ggml_tensor * cur = nullptr; + + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + //cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + + cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); + + // apply lora for embedding tokens if needed + for (const auto & lora : *loras) { + llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd); + if (lw == nullptr) { + continue; + } + + const float adapter_scale = lora.second; + const float scale = lw->get_scale(lora.first->alpha, adapter_scale); + + ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat( + ctx0, lw->b, // non-transposed lora_b + ggml_get_rows(ctx0, lw->a, inp->tokens) + ), scale); + + cur = ggml_add(ctx0, cur, inpL_delta); + } + } else { + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); + ggml_set_input(inp->embd); + + cur = inp->embd; + } + + // For Granite architecture + if (hparams.f_embedding_scale != 0.0f) { + cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); + } + + cb(cur, "inp_embd", -1); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_pos() const { + auto inp = std::make_unique(n_pos_per_token()); + + auto & cur = inp->pos; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token()); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_attn_scale() const { + auto inp = 
std::make_unique(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale); + + auto & cur = inp->attn_scale; + + cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token()); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_out_ids() const { + auto inp = std::make_unique(hparams, cparams, n_outputs); + + auto & cur = inp->out_ids; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_mean() const { + auto inp = std::make_unique(cparams); + + auto & cur = inp->mean; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_cls() const { + auto inp = std::make_unique(cparams); + + auto & cur = inp->cls; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_s_copy() const { + const llama_kv_cache_unified * kv_self = static_cast(memory); + + auto inp = std::make_unique(kv_self); + + const auto n_kv = kv_self->n; + + auto & cur = inp->s_copy; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_s_mask() const { + const llama_kv_cache_unified * kv_self = static_cast(memory); + + auto inp = std::make_unique(kv_self); + + const auto n_kv = kv_self->n; + + auto & cur = inp->s_mask; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_cross_embd() const { + auto inp = std::make_unique(cross); + + auto & cur = inp->cross_embd; + + // if we have the output embeddings from the encoder, use them directly + // TODO: needs more work to be correct, for now just use the tensor shape + //if (cross->t_embd) { + // cur = ggml_view_tensor(ctx0, cross->t_embd); + + // return cur; + //} + + const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd; + const auto n_enc = !cross->v_embd.empty() ? 
cross->n_enc : hparams.n_ctx_train; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { + auto inp = std::make_unique(hparams); + + auto & cur = inp->pos_bucket; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { + const llama_kv_cache_unified * kv_self = static_cast(memory); + + auto inp = std::make_unique(hparams, kv_self); + + const auto n_kv = kv_self->n; + + auto & cur = inp->pos_bucket; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const { + ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]); + cb(pos_bucket_1d, "pos_bucket_1d", -1); + + ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d); + + pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]); + pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3); + pos_bias = ggml_cont (ctx0, pos_bias); + + cb(pos_bias, "pos_bias", -1); + + return pos_bias; +} + +ggml_tensor * llm_graph_context::build_attn_mha( + ggml_cgraph * gf, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * kq_b, + ggml_tensor * kq_mask, + bool v_trans, + float kq_scale) const { + //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + //const int64_t n_head = hparams.n_head(il); + //const int64_t n_head_kv = hparams.n_head_kv(il); + + //const auto & n_embd_head_k = hparams.n_embd_head_k; + //const auto & n_embd_head_v = hparams.n_embd_head_v; + + const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0]; + + const auto n_tokens = q->ne[1]; + const auto n_head = q->ne[2]; + const auto n_kv = k->ne[1]; + + ggml_tensor * cur; + + // TODO: replace hardcoded padding with ggml-provided padding + if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) { + GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); + + if (v_trans) { + v = ggml_transpose(ctx0, v); + } + + // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn) + if (k->type == GGML_TYPE_F32) { + k = ggml_cast(ctx0, k, GGML_TYPE_F16); + } + + if (v->type == GGML_TYPE_F32) { + v = ggml_cast(ctx0, v, GGML_TYPE_F16); + } + + cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, + hparams.attn_soft_cap ? 
hparams.f_attn_logit_softcapping : 0.0f); + + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + + cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); + } else { + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + + // note: this op tends to require high floating point range + // while for some models F16 is enough, for others it is not, so we default to F32 here + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + + if (arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below + + kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx0, kq, 30); + } + + if (hparams.attn_soft_cap) { + kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping); + kq = ggml_tanh (ctx0, kq); + kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping); + } + + if (kq_b) { + kq = ggml_add(ctx0, kq, kq_b); + } + + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + + if (!v_trans) { + // note: avoid this branch + v = ggml_cont(ctx0, ggml_transpose(ctx0, v)); + } + + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + + ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + + if (!cparams.offload_kqv) { + // all nodes between the KV store and the attention output are run on the CPU + ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu); + } + } + + ggml_build_forward_expand(gf, cur); + + return cur; +} + +llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const { + auto inp = std::make_unique(hparams, cparams); + + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch + inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->kq_mask); + + inp->kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; + + return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_no_cache * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + float kq_scale, + int il) const { + GGML_UNUSED(n_tokens); + + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + //cb(q, "q", il); + + ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); + //cb(k, "k", il); + + ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); + //cb(k, "v", il); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + } + + if (wo_b) { + //cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { + const llama_kv_cache_unified * kv_self = static_cast(memory); + + auto inp = std::make_unique(hparams, cparams, kv_self); + + const auto n_kv = kv_self->n; + + inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + + if (hparams.n_swa_pattern > 1) { + GGML_ASSERT(hparams.n_swa > 0); + + inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); + ggml_set_input(inp->self_kq_mask_swa); + + inp->self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + } + + return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_kv_unified * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto & n_ctx = cparams.n_ctx; + + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + const auto n_tokens = q_cur->ne[2]; + + const bool v_trans = !cparams.flash_attn; + + // store to KV cache + { + GGML_ASSERT(!kv_self->recurrent); + + const auto kv_head = kv_self->head; + + GGML_ASSERT(kv_self->size == n_ctx); + + ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head); + //cb(k_cache_view, "k_cache_view", il); + + // note: storing RoPE-ed version of K in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view)); + + v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens); + + ggml_tensor * v_cache_view = nullptr; + + if (!v_trans) { + v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head); + } else { + // note: the V cache is transposed when not using flash attention + v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa, + ( n_ctx)*ggml_element_size(kv_self->v_l[il]), + (kv_head)*ggml_element_size(kv_self->v_l[il])); + + v_cur = ggml_transpose(ctx0, v_cur); + } + //cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view)); + } + + const bool is_swa = hparams.is_swa(il); + + const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); + + const auto n_kv = kv_self->n; + + const int64_t n_head_kv = hparams.n_head_kv(il); + + const auto & n_embd_head_k = hparams.n_embd_head_k; + const auto & n_embd_head_v = hparams.n_embd_head_v; + + ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + //cb(q, "q", il); + + ggml_tensor * k = + ggml_view_3d(ctx0, kv_self->k_l[il], + n_embd_head_k, n_kv, n_head_kv, + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k), + 0); + //cb(k, "k", il); + + ggml_tensor * v = !v_trans ? 
+ ggml_view_3d(ctx0, kv_self->v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v), + 0) : + ggml_view_3d(ctx0, kv_self->v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv_self->v_l[il])*n_ctx, + ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v, + 0); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + } + + if (wo_b) { + //cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { + auto inp = std::make_unique(cross); + + const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; + + inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(inp->cross_kq_mask); + + inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; + + return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_inp_cross_attn_state() const { + const int64_t n_embd = hparams.n_embd; + + auto inp = std::make_unique(); + + ggml_tensor * cur = nullptr; + + inp->cross_attn_state = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, 1601, 4); + ggml_set_input(inp->cross_attn_state); + + cur = inp->cross_attn_state; + + cb(cur, "inp_cross_attn_state", -1); + + res->add_input(std::move(inp)); + + return cur; +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_cross * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto & kq_mask = inp->get_kq_mask_cross(); + + ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + //cb(q, "q", il); + + ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); + //cb(k, "k", il); + + ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); + //cb(k, "v", il); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + } + + if (wo_b) { + //cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +ggml_tensor * llm_graph_context::build_copy_mask_state( + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + int32_t n_state, + int32_t n_seqs) const { + const llama_kv_cache_unified * kv_self = static_cast(memory); + + const auto n_kv = kv_self->n; + const auto kv_head = kv_self->head; + + ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size); + + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // this shrinks the tensors's ne[1] to n_kv + states = ggml_get_rows(ctx0, states, state_copy); + + // clear states of sequences which are starting at the beginning of this batch + // FIXME: zero-out NANs? 
+ states = ggml_mul(ctx0, states, state_mask); + + // copy states which won't be changed further (between n_seqs and n_kv) + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)), + ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); + + // the part of the states that will be used and modified + return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0); +} + +ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( + ggml_cgraph * gf, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + const llama_ubatch & ubatch, + int il) const { + const llama_kv_cache_unified * kv_self = static_cast(memory); + + const auto token_shift_count = hparams.token_shift_count; + + const int64_t n_seqs = ubatch.n_seqs; + + ggml_tensor * token_shift_all = kv_self->k_l[il]; + + ggml_tensor * token_shift = build_copy_mask_state( + gf, token_shift_all, state_copy, state_mask, + hparams.n_embd_k_s(), n_seqs); + + token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); + + return token_shift; +} + +ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( + ggml_tensor * token_shift, + const llama_ubatch & ubatch, + int il) const { + const llama_kv_cache_unified * kv_self = static_cast(memory); + + const auto token_shift_count = hparams.token_shift_count; + const auto n_embd = hparams.n_embd; + + const int64_t n_seqs = ubatch.n_seqs; + + const auto kv_head = kv_self->head; + + return ggml_cpy( + ctx0, + ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), + ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il])) + ); +} + +void llm_graph_context::build_pooling( + ggml_cgraph * gf, + ggml_tensor * cls, + ggml_tensor * cls_b, + ggml_tensor * cls_out, + ggml_tensor * cls_out_b) const { + if (!cparams.embeddings) { + return; + } + + ggml_tensor * inp = res->t_embd; + + //// find result_norm tensor for input + //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) { + // inp = ggml_graph_node(gf, i); + // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) { + // break; + // } + + // inp = nullptr; + //} + + GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor"); + + ggml_tensor * cur; + + switch (pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + cur = inp; + } break; + case LLAMA_POOLING_TYPE_MEAN: + { + ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } break; + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } break; + case LLAMA_POOLING_TYPE_RANK: + { + ggml_tensor * inp_cls = build_inp_cls(); + inp = ggml_get_rows(ctx0, inp, inp_cls); + + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 + GGML_ASSERT(cls != nullptr); + GGML_ASSERT(cls_b != nullptr); + + cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b); + cur = ggml_tanh(ctx0, cur); + + // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + // 
https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                if (cls_out) {
+                    GGML_ASSERT(cls_out_b != nullptr);
+
+                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+                }
+            } break;
+        default:
+            {
+                GGML_ABORT("unknown pooling type");
+            }
+    }
+
+    cb(cur, "result_embd_pooled", -1);
+    res->t_embd_pooled = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
diff --git a/llama/llama.cpp/src/llama-graph.h b/llama/llama.cpp/src/llama-graph.h
new file mode 100644
index 000000000..51993998a
--- /dev/null
+++ b/llama/llama.cpp/src/llama-graph.h
@@ -0,0 +1,604 @@
+#pragma once
+
+#include "llama-arch.h"
+#include "llama-hparams.h"
+#include "llama-adapter.h"
+
+#include <cstdint>
+#include <vector>
+#include <memory>
+#include <set>
+#include <functional>
+
+struct ggml_cgraph;
+struct ggml_context;
+struct ggml_tensor;
+
+struct llama_ubatch;
+struct llama_cparams;
+
+class llama_memory_i;
+class llama_kv_cache_unified;
+
+// certain models (typically multi-modal) can produce different types of graphs
+enum llm_graph_type {
+    LLM_GRAPH_TYPE_DEFAULT,
+    LLM_GRAPH_TYPE_ENCODER,
+    LLM_GRAPH_TYPE_DECODER,
+};
+
+enum llm_ffn_op_type {
+    LLM_FFN_SILU,
+    LLM_FFN_GELU,
+    LLM_FFN_RELU,
+    LLM_FFN_RELU_SQR,
+    LLM_FFN_SWIGLU,
+};
+
+enum llm_ffn_gate_type {
+    LLM_FFN_SEQ,
+    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
+};
+
+enum llm_norm_type {
+    LLM_NORM,
+    LLM_NORM_RMS,
+    LLM_NORM_GROUP,
+};
+
+// TODO: tmp - need something better to pass the data from the encoder to the decoder
+struct llama_cross {
+    // the output embeddings from the encoder as a ggml tensor
+    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
+    // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
+    //ggml_tensor * t_embd = nullptr;
+
+    int64_t n_embd = 0;
+    int64_t n_enc = 0;
+
+    // embeddings data copied to host memory (tmp)
+    std::vector<float> v_embd;
+
+    // needed to construct the cross-attention mask in the decoder
+    std::vector<std::set<llama_seq_id>> seq_ids_enc;
+};
+
+//
+// llm_graph_input
+//
+
+class llm_graph_input_i {
+public:
+    virtual ~llm_graph_input_i() = default;
+
+    virtual void set_input(const llama_ubatch * ubatch) = 0;
+};
+
+using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
+
+
+class llm_graph_input_embd : public llm_graph_input_i {
+public:
+    llm_graph_input_embd() = default;
+    virtual ~llm_graph_input_embd() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * tokens = nullptr; // I32 [n_batch]
+    ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
+    ggml_tensor * cross_attn_state; // F32 [4, n_embd, 1061]
+};
+
+class llm_graph_input_pos : public llm_graph_input_i {
+public:
+    llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
+    virtual ~llm_graph_input_pos() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * pos = nullptr; // I32 [n_batch]
+
+    const int64_t n_pos_per_token = 1;
+};
+
+// temperature tuning, used by llama4
+class llm_graph_input_attn_temp : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
+        : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    virtual ~llm_graph_input_attn_temp() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
+
+    const int64_t n_pos_per_token = 1;
+
+    
const uint32_t n_attn_temp_floor_scale; + const float f_attn_temp_scale; +}; + +class llm_graph_input_pos_bucket : public llm_graph_input_i { +public: + llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {} + virtual ~llm_graph_input_pos_bucket() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch] + + const llama_hparams & hparams; +}; + +class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { +public: + llm_graph_input_pos_bucket_kv( + const llama_hparams & hparams, + const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {} + virtual ~llm_graph_input_pos_bucket_kv() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] + + const llama_hparams & hparams; + const llama_kv_cache_unified * kv_self; +}; + +class llm_graph_input_out_ids : public llm_graph_input_i { +public: + llm_graph_input_out_ids( + const llama_hparams & hparams, + const llama_cparams & cparams, + int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {} + virtual ~llm_graph_input_out_ids() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * out_ids; // I32 [n_outputs] + + const llama_hparams & hparams; + const llama_cparams & cparams; + + const int32_t n_outputs; +}; + +class llm_graph_input_mean : public llm_graph_input_i { +public: + llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {} + virtual ~llm_graph_input_mean() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * mean; // F32 [n_batch, n_batch] + + const llama_cparams & cparams; +}; + +class llm_graph_input_cls : public llm_graph_input_i { +public: + llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {} + virtual ~llm_graph_input_cls() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * cls; // I32 [n_batch] + + const llama_cparams & cparams; +}; + +class llm_graph_input_s_copy : public llm_graph_input_i { +public: + llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_s_copy() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * s_copy; // I32 [kv_size] + + const llama_kv_cache_unified * kv_self; +}; + +class llm_graph_input_s_mask : public llm_graph_input_i { +public: + llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_s_mask() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * s_mask; // F32 [1, n_kv] + + const llama_kv_cache_unified * kv_self; +}; + +class llm_graph_input_cross_embd : public llm_graph_input_i { +public: + llm_graph_input_cross_embd( + const llama_cross * cross) : cross(cross) {} + virtual ~llm_graph_input_cross_embd() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc] + + const llama_cross * cross; +}; + +class llm_graph_input_attn_no_cache : public llm_graph_input_i { +public: + llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) : + hparams(hparams), + cparams(cparams) { + } + ~llm_graph_input_attn_no_cache() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } + + ggml_tensor * 
kq_mask = nullptr; // F32 [n_tokens, n_batch] + ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch] + + const llama_hparams & hparams; + const llama_cparams & cparams; +}; + +class llm_graph_input_attn_kv_unified : public llm_graph_input_i { +public: + llm_graph_input_attn_kv_unified( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_unified * kv_self) : + hparams(hparams), + cparams(cparams), + kv_self(kv_self) { + } + ~llm_graph_input_attn_kv_unified() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch] + + const llama_hparams & hparams; + const llama_cparams & cparams; + + const llama_kv_cache_unified * kv_self; +}; + +class llm_graph_input_attn_cross : public llm_graph_input_i { +public: + llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} + ~llm_graph_input_attn_cross() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } + + ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch] + + const llama_cross * cross = nullptr; +}; + +class llm_graph_input_cross_attn_state : public llm_graph_input_i { +public: + llm_graph_input_cross_attn_state() = default; + virtual ~llm_graph_input_cross_attn_state() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * cross_attn_state; // F32 [4, n_embd, 1061] +}; + +// +// llm_graph_result +// + +// these objects deliver the result from the graph build process back to the llama_context +// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their +// specific data, by calling the set_inputs() method +// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc. 
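+// (in llm_graph_result below, these are the t_logits, t_embd and t_embd_pooled members)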
+// these are used by the llama_context to extact the relevant data, based on the compute parameters + +class llm_graph_result_i { +public: + virtual ~llm_graph_result_i() = default; + + virtual ggml_tensor * get_logits() = 0; + virtual ggml_tensor * get_embd() = 0; + virtual ggml_tensor * get_embd_pooled() = 0; + + virtual void set_inputs(const llama_ubatch * ubatch) = 0; +}; + +using llm_graph_result_ptr = std::unique_ptr; + + +class llm_graph_result : public llm_graph_result_i { +public: + virtual ~llm_graph_result() = default; + + ggml_tensor * get_logits() override { return t_logits; } + ggml_tensor * get_embd() override { return t_embd; } + ggml_tensor * get_embd_pooled() override { return t_embd_pooled; } + + void set_inputs(const llama_ubatch * ubatch) override { + for (auto & input : inputs) { + input->set_input(ubatch); + } + } + + llm_graph_input_i * add_input(llm_graph_input_ptr input) { + inputs.emplace_back(std::move(input)); + return inputs.back().get(); + } + + // important graph nodes + ggml_tensor * t_logits = nullptr; + ggml_tensor * t_embd = nullptr; + ggml_tensor * t_embd_pooled = nullptr; + + std::vector inputs; +}; + +// +// llm_graph_context +// + +// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) +using llm_graph_cb = std::function; + +struct llm_graph_params { + ggml_context * ctx; + + const llm_arch arch; + + const llama_hparams & hparams; + const llama_cparams & cparams; + const llama_ubatch & ubatch; + + ggml_backend_sched * sched; + ggml_backend * backend_cpu; + + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_i * memory; + const llama_cross * cross; + + int32_t n_outputs; + + const llm_graph_cb & cb; +}; + +struct llm_graph_context { + const llm_arch arch; + + const llama_hparams & hparams; + const llama_cparams & cparams; + const llama_ubatch & ubatch; + + const int64_t n_embd; + const int64_t n_layer; + const int64_t n_rot; + const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) + const int64_t n_ctx_per_seq; + const int64_t n_head; + const int64_t n_head_kv; + const int64_t n_embd_head_k; + const int64_t n_embd_k_gqa; + const int64_t n_embd_head_v; + const int64_t n_embd_v_gqa; + const int64_t n_expert; + const int64_t n_expert_used; + + const float freq_base; + const float freq_scale; + const float ext_factor; + const float attn_factor; + const float beta_fast; + const float beta_slow; + const float norm_eps; + const float norm_rms_eps; + + const int32_t n_tokens; + const int32_t n_outputs; + const int32_t n_ctx_orig; // yarn + + const enum llama_pooling_type pooling_type; + const enum llama_rope_type rope_type; + + ggml_context * ctx0 = nullptr; + + ggml_backend_sched * sched; + + ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? 
+ + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_i * memory; + const llama_cross * cross; + + const llm_graph_cb & cb_func; + + std::unique_ptr res; + + llm_graph_context(const llm_graph_params & params); + + int64_t n_pos_per_token() const; + + void cb(ggml_tensor * cur, const char * name, int il) const; + + // + // common + // + + ggml_tensor * build_cvec( + ggml_tensor * cur, + int il) const; + + // do mat_mul, while optionally apply lora + ggml_tensor * build_lora_mm( + ggml_tensor * w, + ggml_tensor * cur) const; + + // do mat_mul_id, while optionally apply lora + ggml_tensor * build_lora_mm_id( + ggml_tensor * w, // ggml_tensor * as + ggml_tensor * cur, // ggml_tensor * b + ggml_tensor * ids) const; + + ggml_tensor * build_norm( + ggml_tensor * cur, + ggml_tensor * mw, + ggml_tensor * mb, + llm_norm_type type, + int il) const; + + ggml_tensor * build_ffn( + ggml_tensor * cur, + ggml_tensor * up, + ggml_tensor * up_b, + ggml_tensor * up_s, + ggml_tensor * gate, + ggml_tensor * gate_b, + ggml_tensor * gate_s, + ggml_tensor * down, + ggml_tensor * down_b, + ggml_tensor * down_s, + ggml_tensor * act_scales, + llm_ffn_op_type type_op, + llm_ffn_gate_type type_gate, + int il) const; + + ggml_tensor * build_moe_ffn( + ggml_tensor * cur, + ggml_tensor * gate_inp, + ggml_tensor * up_exps, + ggml_tensor * gate_exps, + ggml_tensor * down_exps, + ggml_tensor * exp_probs_b, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + bool scale_w, + float w_scale, + llama_expert_gating_func_type gating_op, + int il) const; + + // + // inputs + // + + ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; + ggml_tensor * build_inp_pos() const; + ggml_tensor * build_inp_attn_scale() const; + ggml_tensor * build_inp_out_ids() const; + ggml_tensor * build_inp_mean() const; + ggml_tensor * build_inp_cls() const; + ggml_tensor * build_inp_s_copy() const; + ggml_tensor * build_inp_s_mask() const; + ggml_tensor * build_inp_cross_attn_state() const; + + ggml_tensor * build_inp_cross_embd() const; + ggml_tensor * build_inp_pos_bucket_enc() const; + ggml_tensor * build_inp_pos_bucket_dec() const; + ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const; + + // + // attention + // + + ggml_tensor * build_attn_mha( + ggml_cgraph * gf, + ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q] + ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k] + ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) + ggml_tensor * kq_b, + ggml_tensor * kq_mask, + bool v_trans, + float kq_scale) const; + + llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_no_cache * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + float kq_scale, + int il) const; + + llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_kv_unified * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + float kq_scale, + int il) const; + + 
llm_graph_input_attn_cross * build_attn_inp_cross() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_cross * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + float kq_scale, + int il) const; + + // + // recurrent + // + + ggml_tensor * build_copy_mask_state( + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + int32_t n_state, + int32_t n_seqs) const; + + ggml_tensor * build_rwkv_token_shift_load( + ggml_cgraph * gf, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + const llama_ubatch & ubatch, + int il) const; + + ggml_tensor * build_rwkv_token_shift_store( + ggml_tensor * token_shift, + const llama_ubatch & ubatch, + int il) const; + + // + // pooling + // + + void build_pooling( + ggml_cgraph * gf, + ggml_tensor * cls, + ggml_tensor * cls_b, + ggml_tensor * cls_out, + ggml_tensor * cls_out_b) const; +}; diff --git a/llama/llama.cpp/src/llama-hparams.cpp b/llama/llama.cpp/src/llama-hparams.cpp index 0b8410281..6a02de036 100644 --- a/llama/llama.cpp/src/llama-hparams.cpp +++ b/llama/llama.cpp/src/llama-hparams.cpp @@ -2,8 +2,6 @@ #include "ggml.h" -#include - uint32_t llama_hparams::n_head(uint32_t il) const { if (il < n_layer) { return n_head_arr[il]; @@ -80,6 +78,14 @@ bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const { GGML_ABORT("fatal error"); } +bool llama_hparams::is_swa(uint32_t il) const { + if (il < n_layer) { + return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1); + } + + GGML_ABORT("fatal error"); +} + bool llama_hparams::cross_attention_layers(uint32_t il) const { return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); -} \ No newline at end of file +} diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index 053830461..4567a0e90 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -2,6 +2,8 @@ #include "llama.h" +#include + #include // bump if necessary @@ -36,6 +38,7 @@ struct llama_hparams { uint32_t n_layer; uint32_t n_rot; uint32_t n_swa = 0; // sliding window attention (SWA) + uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention uint32_t n_embd_head_k; // dimension of keys (d_k). 
d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; @@ -79,10 +82,16 @@ struct llama_hparams { uint32_t time_decay_extra_dim = 0; uint32_t wkv_head_size = 0; uint32_t token_shift_count = 2; + uint32_t n_lora_decay = 0; + uint32_t n_lora_iclr = 0; + uint32_t n_lora_value_res_mix = 0; + uint32_t n_lora_gate = 0; float rope_attn_factor = 1.0f; float rope_freq_base_train; + float rope_freq_base_train_swa; float rope_freq_scale_train; + float rope_freq_scale_train_swa; uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul; @@ -109,6 +118,14 @@ struct llama_hparams { bool use_alibi = false; bool attn_soft_cap = false; + uint32_t n_moe_layer_step = 0; + bool use_kq_norm = true; + uint32_t n_attn_chunk = 0; + // values below seems to be fixed on llama4 + uint32_t n_no_rope_layer_step = 4; + uint32_t n_attn_temp_floor_scale = 8192; + float f_attn_temp_scale = 0.1; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; @@ -143,6 +160,8 @@ struct llama_hparams { // cross attention layers bool cross_attention_layers(uint32_t il) const; + + bool is_swa(uint32_t il) const; }; static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); diff --git a/llama/llama.cpp/src/llama-io.cpp b/llama/llama.cpp/src/llama-io.cpp new file mode 100644 index 000000000..7ad70d163 --- /dev/null +++ b/llama/llama.cpp/src/llama-io.cpp @@ -0,0 +1,15 @@ +#include "llama-io.h" + +void llama_io_write_i::write_string(const std::string & str) { + uint32_t str_size = str.size(); + + write(&str_size, sizeof(str_size)); + write(str.data(), str_size); +} + +void llama_io_read_i::read_string(std::string & str) { + uint32_t str_size; + read_to(&str_size, sizeof(str_size)); + + str.assign((const char *) read(str_size), str_size); +} diff --git a/llama/llama.cpp/src/llama-io.h b/llama/llama.cpp/src/llama-io.h new file mode 100644 index 000000000..ce9216b83 --- /dev/null +++ b/llama/llama.cpp/src/llama-io.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include + +struct ggml_tensor; + +class llama_io_write_i { +public: + llama_io_write_i() = default; + virtual ~llama_io_write_i() = default; + + virtual void write(const void * src, size_t size) = 0; + virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0; + + // bytes written so far + virtual size_t n_bytes() = 0; + + void write_string(const std::string & str); +}; + +class llama_io_read_i { +public: + llama_io_read_i() = default; + virtual ~llama_io_read_i() = default; + + virtual const uint8_t * read(size_t size) = 0; + virtual void read_to(void * dst, size_t size) = 0; + + // bytes read so far + virtual size_t n_bytes() = 0; + + void read_string(std::string & str); +}; diff --git a/llama/llama.cpp/src/llama-kv-cache.cpp b/llama/llama.cpp/src/llama-kv-cache.cpp index b541c5a37..5c941e7c4 100644 --- a/llama/llama.cpp/src/llama-kv-cache.cpp +++ b/llama/llama.cpp/src/llama-kv-cache.cpp @@ -6,86 +6,90 @@ #include "llama-model.h" #include +#include #include #include +#include -static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; - -uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 
256u : 32u; +llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { } -bool llama_kv_cache_init( - struct llama_kv_cache & cache, - const llama_model & model, - const llama_cparams & cparams, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size, - bool offload) { - const struct llama_hparams & hparams = model.hparams; - +bool llama_kv_cache_unified::init( + const llama_model & model, + const llama_cparams & cparams, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + bool offload) { const int32_t n_layer = hparams.n_layer; - cache.has_shift = false; + has_shift = false; - cache.recurrent = llama_model_is_recurrent(&model); - cache.v_trans = !cache.recurrent && !cparams.flash_attn; - cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + recurrent = llama_model_is_recurrent(&model); + v_trans = !recurrent && !cparams.flash_attn; + can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", - __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift); + __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); - cache.head = 0; - cache.size = kv_size; - cache.used = 0; + head = 0; + size = kv_size; + used = 0; - cache.type_k = type_k; - cache.type_v = type_v; + this->type_k = type_k; + this->type_v = type_v; - cache.cells.clear(); - cache.cells.resize(kv_size); + cells.clear(); + cells.resize(kv_size); // create a context for each buffer type std::map ctx_map; auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { - struct ggml_init_params params = { + ggml_init_params params = { /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; + ggml_context * ctx = ggml_init(params); if (!ctx) { return nullptr; } + ctx_map[buft] = ctx; - cache.ctxs.emplace_back(ctx); + ctxs.emplace_back(ctx); + return ctx; } + return it->second; }; - cache.k_l.reserve(n_layer); - cache.v_l.reserve(n_layer); + k_l.reserve(n_layer); + v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa); + const char * dev_name = "CPU"; ggml_backend_buffer_type_t buft; if (offload) { auto * dev = model.dev_layer(i); buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); } else { buft = ggml_backend_cpu_buffer_type(); } - ggml_context * ctx = ctx_for_buft(buft); + LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__, + i, n_embd_k_gqa, n_embd_v_gqa, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); return false; @@ -101,11 +105,10 @@ bool llama_kv_cache_init( k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); } - ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); - cache.k_l.push_back(k); - cache.v_l.push_back(v); + 
k_l.push_back(k); + v_l.push_back(v); } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -120,280 +123,80 @@ bool llama_kv_cache_init( } ggml_backend_buffer_clear(buf, 0); LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - cache.bufs.emplace_back(buf); + bufs.emplace_back(buf); } return true; } -struct llama_kv_cache_slot_info llama_kv_cache_find_slot( - struct llama_kv_cache & cache, - const struct llama_ubatch & ubatch) { - const uint32_t n_tokens = ubatch.n_tokens; - const uint32_t n_seqs = ubatch.n_seqs; - const uint32_t n_seq_tokens = ubatch.n_seq_tokens; +int32_t llama_kv_cache_unified::get_n_tokens() const { + int32_t result = 0; - if (cache.recurrent) { - // For recurrent state architectures (like Mamba or RWKV), - // each cache cell can store the state for a whole sequence. - // A slot should be always be contiguous. - - // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(ubatch.equal_seqs); - - int32_t min = cache.size - 1; - int32_t max = 0; - - // everything should fit if all seq_ids are smaller than the max - for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = ubatch.n_seq_id[s]; - for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - - if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { - // too big seq_id - // TODO: would it be possible to resize the cache instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); - return llama_kv_cache_slot_info_failed; - } - if (j > 0) { - llama_kv_cell & seq = cache.cells[seq_id]; - if (seq.tail >= 0) { - llama_kv_cell & cell = cache.cells[seq.tail]; - // clear cells from seq_ids that become shared - // (should not normally happen, but let's handle it anyway) - cell.seq_id.erase(seq_id); - seq.tail = -1; - if (cell.seq_id.empty()) { - cell.pos = -1; - cell.src = -1; - cache.used -= 1; - } - } - } - } - } - -#ifndef NDEBUG - { - std::vector tails_verif; - tails_verif.assign(cache.size, -1); - for (uint32_t i = 0; i < cache.size; ++i) { - llama_kv_cell & cell = cache.cells[i]; - for (llama_seq_id seq_id : cell.seq_id) { - if (tails_verif[seq_id] != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); - } - tails_verif[seq_id] = i; - } - } - for (uint32_t i = 0; i < cache.size; ++i) { - if (tails_verif[i] != cache.cells[i].tail) { - LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]); - } - } - } -#endif - - // find next empty cell - uint32_t next_empty_cell = cache.head; - - for (uint32_t i = 0; i < cache.size; ++i) { - if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } - llama_kv_cell & cell = cache.cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - - // find usable cell range - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - llama_kv_cell & seq_meta = cache.cells[seq_id]; - bool has_cell = false; - if (seq_meta.tail >= 0) { - llama_kv_cell & cell = cache.cells[seq_meta.tail]; - GGML_ASSERT(cell.has_seq_id(seq_id)); - // does this seq_id "own" the cell? 
- if (cell.seq_id.size() == 1) { has_cell = true; } - } - if (!has_cell) { - llama_kv_cell & empty_cell = cache.cells[next_empty_cell]; - GGML_ASSERT(empty_cell.is_empty()); - // copy old tail into the empty cell - if (seq_meta.tail >= 0) { - llama_kv_cell & orig_cell = cache.cells[seq_meta.tail]; - empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); - empty_cell.seq_id.insert(seq_id); // will be overwritten - } - seq_meta.tail = next_empty_cell; - // find next empty cell - if (s + 1 < n_seqs) { - next_empty_cell += 1; - for (uint32_t i = 0; i < cache.size; ++i) { - if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } - llama_kv_cell & cell = cache.cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - } - } - if (min > seq_meta.tail) { min = seq_meta.tail; } - if (max < seq_meta.tail) { max = seq_meta.tail; } - } - - // gather and re-order - for (uint32_t s = 0; s < n_seqs; ++s) { - int32_t dst_id = s + min; - int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail; - if (dst_id != src_id) { - llama_kv_cell & dst_cell = cache.cells[dst_id]; - llama_kv_cell & src_cell = cache.cells[src_id]; - - std::swap(dst_cell.pos, src_cell.pos); - std::swap(dst_cell.src, src_cell.src); - std::swap(dst_cell.seq_id, src_cell.seq_id); - - // swap tails (assuming they NEVER overlap) - for (const llama_seq_id seq_id : src_cell.seq_id) { - cache.cells[seq_id].tail = src_id; - } - for (const llama_seq_id seq_id : dst_cell.seq_id) { - cache.cells[seq_id].tail = dst_id; - } - } - } - - // update the pos of the used seqs - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; - int32_t cell_id = s + min; - llama_kv_cell & cell = cache.cells[cell_id]; - - if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); - } - cell.pos = last_pos; - cell.seq_id.clear(); - for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - cell.seq_id.insert(seq_id); - cache.cells[seq_id].tail = cell_id; - } - } - - // allow getting the range of used cells, from head to head + n - cache.head = min; - cache.n = max - min + 1; - cache.used = std::count_if(cache.cells.begin(), cache.cells.end(), - [](const llama_kv_cell& cell){ return !cell.is_empty(); }); - - // sanity check - return llama_kv_cache_slot_info(cache.n >= n_seqs); - } - // otherwise, one cell per token. 
- - if (n_tokens > cache.size) { - LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); - return llama_kv_cache_slot_info_failed; + for (uint32_t i = 0; i < size; i++) { + result += cells[i].seq_id.size(); } - uint32_t n_tested = 0; - - while (true) { - if (cache.head + n_tokens > cache.size) { - n_tested += cache.size - cache.head; - cache.head = 0; - continue; - } - - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cache.cells[cache.head + i].pos >= 0) { - found = false; - cache.head += i + 1; - n_tested += i + 1; - break; - } - } - - if (found) { - break; - } - - if (n_tested >= cache.size) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return llama_kv_cache_slot_info_failed; - } - } - - for (uint32_t s = 0; s < n_seqs; s++) { - for (uint32_t i = 0; i < n_seq_tokens; ++i) { - uint32_t k = s*n_seq_tokens + i; - cache.cells[cache.head + k].pos = ubatch.pos[k]; - - for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { - cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]); - } - } - } - - cache.used += n_tokens; - - return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens); + return result; } -uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { - for (uint32_t i = cache.size; i > 0; --i) { - const llama_kv_cell & cell = cache.cells[i - 1]; - - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } - } - - return 0; +int32_t llama_kv_cache_unified::get_used_cells() const { + return used; } -void llama_kv_cache_clear(struct llama_kv_cache & cache) { - for (int32_t i = 0; i < (int32_t) cache.size; ++i) { - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - cache.cells[i].src = -1; - cache.cells[i].tail = -1; +size_t llama_kv_cache_unified::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); } - cache.head = 0; - cache.used = 0; - for (auto & buf : cache.bufs) { + return size; +} + +llama_pos llama_kv_cache_unified::pos_max() const { + llama_pos pos_max = -1; + for (const auto & cell : cells) { + pos_max = std::max(pos_max, cell.pos); + } + + return pos_max; +} + +void llama_kv_cache_unified::clear() { + for (int32_t i = 0; i < (int32_t) size; ++i) { + cells[i].pos = -1; + cells[i].seq_id.clear(); + cells[i].src = -1; + cells[i].tail = -1; + } + head = 0; + used = 0; + + for (auto & buf : bufs) { ggml_backend_buffer_clear(buf.get(), 0); } } -bool llama_kv_cache_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - uint32_t new_head = cache.size; +bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = size; - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } // models like Mamba or RWKV can't have a state partially erased - if (cache.recurrent) { - if (seq_id >= (int64_t) cache.size) { + if (recurrent) { + if (seq_id >= (int64_t) size) { // could be fatal return false; } if (0 <= seq_id) { - int32_t & tail_id = cache.cells[seq_id].tail; + int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const llama_kv_cell & cell = cache.cells[tail_id]; + const llama_kv_cell & cell = cells[tail_id]; // partial intersection is invalid if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { return false; @@ -409,50 +212,63 @@ bool llama_kv_cache_seq_rm( return false; } } + + 
return true; } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].pos >= p0 && cells[i].pos < p1) { if (seq_id < 0) { - cache.cells[i].seq_id.clear(); - } else if (cache.cells[i].has_seq_id(seq_id)) { - cache.cells[i].seq_id.erase(seq_id); + cells[i].seq_id.clear(); + } else if (cells[i].has_seq_id(seq_id)) { + cells[i].seq_id.erase(seq_id); } else { continue; } - if (cache.cells[i].is_empty()) { + if (cells[i].is_empty()) { // keep count of the number of used cells - if (cache.cells[i].pos >= 0) cache.used--; + if (cells[i].pos >= 0) { + used--; + } - cache.cells[i].pos = -1; - cache.cells[i].src = -1; - if (new_head == cache.size) new_head = i; + cells[i].pos = -1; + cells[i].src = -1; + + if (new_head == size) { + new_head = i; + } } } } // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size && new_head < cache.head) cache.head = new_head; + if (new_head != size && new_head < head) { + head = new_head; + } return true; } -void llama_kv_cache_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); +void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } - if (cache.recurrent) { - if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { - llama_kv_cell & tail_src = cache.cells[seq_id_src]; - llama_kv_cell & tail_dst = cache.cells[seq_id_dst]; + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + if (recurrent) { + if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { + llama_kv_cell & tail_src = cells[seq_id_src]; + llama_kv_cell & tail_dst = cells[seq_id_dst]; if (tail_dst.tail >= 0) { // clear destination seq_id if it wasn't empty - llama_kv_cell & cell_dst = cache.cells[tail_dst.tail]; + llama_kv_cell & cell_dst = cells[tail_dst.tail]; cell_dst.seq_id.erase(seq_id_dst); tail_dst.tail = -1; @@ -460,11 +276,11 @@ void llama_kv_cache_seq_cp( cell_dst.pos = -1; cell_dst.delta = -1; cell_dst.src = -1; - cache.used -= 1; + used -= 1; } } if (tail_src.tail >= 0) { - llama_kv_cell & cell_src = cache.cells[tail_src.tail]; + llama_kv_cell & cell_src = cells[tail_src.tail]; cell_src.seq_id.insert(seq_id_dst); tail_dst.tail = tail_src.tail; @@ -473,59 +289,75 @@ void llama_kv_cache_seq_cp( return; } - // otherwise, this is the KV cache of a Transformer-like model - cache.head = 0; + // otherwise, this is the KV of a Transformer-like model + head = 0; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].seq_id.insert(seq_id_dst); + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) { + cells[i].seq_id.insert(seq_id_dst); } } } -void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { - uint32_t new_head = cache.size; +void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = size; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.recurrent && (llama_seq_id) i != seq_id) { - cache.cells[i].tail = -1; + for (uint32_t i = 0; i < size; ++i) { + if (recurrent && (llama_seq_id) i != seq_id) { + 
cells[i].tail = -1; } - if (!cache.cells[i].has_seq_id(seq_id)) { - if (cache.cells[i].pos >= 0) cache.used--; - cache.cells[i].pos = -1; - cache.cells[i].src = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) new_head = i; + + if (!cells[i].has_seq_id(seq_id)) { + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + cells[i].seq_id.clear(); + + if (new_head == size){ + new_head = i; + } } else { - cache.cells[i].seq_id.clear(); - cache.cells[i].seq_id.insert(seq_id); + cells[i].seq_id.clear(); + cells[i].seq_id.insert(seq_id); } } // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size && new_head < cache.head) cache.head = new_head; + if (new_head != size && new_head < head) { + head = new_head; + } } -void llama_kv_cache_seq_add( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - uint32_t new_head = cache.size; +void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (delta == 0) { + return; + } - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); - // If there is no range then return early to avoid looping over the cache. - if (p0 == p1) return; + uint32_t new_head = size; - if (cache.recurrent) { + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the + if (p0 == p1) { + return; + } + + if (recurrent) { // for Mamba-like or RWKV models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - const int32_t tail_id = cache.cells[seq_id].tail; + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; if (tail_id >= 0) { - llama_kv_cell & cell = cache.cells[tail_id]; + llama_kv_cell & cell = cells[tail_id]; if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { cell.pos += delta; } @@ -534,19 +366,19 @@ void llama_kv_cache_seq_add( return; } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; - cache.cells[i].pos += delta; - cache.cells[i].delta += delta; + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { + has_shift = true; + cells[i].pos += delta; + cells[i].delta += delta; - if (cache.cells[i].pos < 0) { - if (!cache.cells[i].is_empty()) { - cache.used--; + if (cells[i].pos < 0) { + if (!cells[i].is_empty()) { + used--; } - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) { + cells[i].pos = -1; + cells[i].seq_id.clear(); + if (new_head == size) { new_head = i; } } @@ -555,93 +387,890 @@ void llama_kv_cache_seq_add( // If we freed up a slot, set head to it so searching can start there. // Otherwise we just start the next search from the beginning. - cache.head = new_head != cache.size ? new_head : 0; + head = new_head != size ? new_head : 0; } -void llama_kv_cache_seq_div( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); - // If there is no range then return early to avoid looping over the cache. 
- if (p0 == p1) return; +void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } - if (cache.recurrent) { + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + if (recurrent) { // for Mamba-like or RWKV models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - const int32_t tail_id = cache.cells[seq_id].tail; + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; if (tail_id >= 0) { - llama_kv_cell & cell = cache.cells[tail_id]; + llama_kv_cell & cell = cells[tail_id]; if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { cell.pos /= d; } } } + return; } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { + has_shift = true; { - llama_pos p_old = cache.cells[i].pos; - cache.cells[i].pos /= d; - cache.cells[i].delta += cache.cells[i].pos - p_old; + llama_pos p_old = cells[i].pos; + cells[i].pos /= d; + cells[i].delta += cells[i].pos - p_old; } } } } -llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { +llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { llama_pos result = 0; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id)) { - result = std::max(result, cache.cells[i].pos); + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::max(result, cells[i].pos); } } return result; } -void llama_kv_cache_defrag(struct llama_kv_cache & cache) { - if (!cache.recurrent) { - cache.do_defrag = true; +void llama_kv_cache_unified::defrag() { + if (!recurrent) { + do_defrag = true; } } -int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv) { - int result = 0; - - for (uint32_t i = 0; i < kv.size; i++) { - result += kv.cells[i].seq_id.size(); +void llama_kv_cache_unified::restore() { + if (pending.ranges.empty()) { + return; } - return result; + // TODO: tmp - move to llama_kv_cache_recurrent + if (recurrent) { + seq_rm(-1, -1, -1); + return; + } + + uint32_t new_head = size; + + for (auto & range : pending.ranges) { + for (uint32_t i = range.c0; i < range.c1; ++i) { + cells[i].seq_id.clear(); + + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + } + + new_head = std::min(new_head, range.c0); + } + + if (new_head != size && new_head < head) { + head = new_head; + } } -int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv) { - return kv.used; +void llama_kv_cache_unified::commit() { + // TODO: tmp - move to llama_kv_cache_recurrent + if (recurrent) { + return; + } + + if (pending.ranges.empty()) { + LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/12695"); + return; + } + + pending.ranges.clear(); } -bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv) { - return kv.can_shift; +bool llama_kv_cache_unified::get_can_shift() const { + return can_shift; +} + +bool llama_kv_cache_unified::find_slot( + const llama_ubatch & 
ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*ubatch.n_tokens) { + head = 0; + } + + if (recurrent) { + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); + return false; + } + if (j > 0) { + llama_kv_cell & seq = cells[seq_id]; + if (seq.tail >= 0) { + llama_kv_cell & cell = cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(size, -1); + for (uint32_t i = 0; i < size; ++i) { + llama_kv_cell & cell = cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < size; ++i) { + if (tails_verif[i] != cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = head; + + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + llama_kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + llama_kv_cell & seq_meta = cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + llama_kv_cell & cell = cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? 
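+                // (i.e. the tail cell can only be reused in place when no other sequence
+                //  shares it; otherwise a fresh empty cell is claimed just below)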
+ if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + llama_kv_cell & empty_cell = cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + llama_kv_cell & orig_cell = cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + llama_kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + llama_kv_cell & dst_cell = cells[dst_id]; + llama_kv_cell & src_cell = cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + llama_kv_cell & cell = cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cells[seq_id].tail = cell_id; + } + } + + // allow getting the range of used cells, from head to head + n + head = min; + n = max - min + 1; + used = std::count_if(cells.begin(), cells.end(), + [](const llama_kv_cell& cell){ return !cell.is_empty(); }); + + // sanity check + return n >= n_seqs; + } + + // otherwise, one cell per token. 
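+    // (the search below walks the ring buffer from `head` looking for a contiguous run of
+    //  n_tokens free cells (pos < 0), wrapping back to 0 at `size` and giving up once every
+    //  cell has been tested; e.g. size = 8, head = 6, n_tokens = 4 wraps and tries cells 0..3)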
+ + if (n_tokens > size) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); + return false; + } + + uint32_t n_tested = 0; + + while (true) { + if (head + n_tokens > size) { + n_tested += size - head; + head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cells[head + i].pos >= 0) { + found = false; + head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } + } + + for (uint32_t s = 0; s < n_seqs; s++) { + for (uint32_t i = 0; i < n_seq_tokens; ++i) { + uint32_t k = s*n_seq_tokens + i; + cells[head + k].pos = ubatch.pos[k]; + + for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { + cells[head + k].seq_id.insert(ubatch.seq_id[s][j]); + } + } + } + + used += n_tokens; + + pending.ranges.push_back({head, head + n_tokens}); + + return true; +} + +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +} + +uint32_t llama_kv_cache_unified::cell_max() const { + for (uint32_t i = size; i > 0; --i) { + const llama_kv_cell & cell = cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + +size_t llama_kv_cache_unified::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & k : k_l) { + size_k_bytes += ggml_nbytes(k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_unified::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & v : v_l) { + size_v_bytes += ggml_nbytes(v); + } + + return size_v_bytes; +} + +bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { + const uint32_t n_layer = hparams.n_layer; + + const uint32_t n_kv = cell_max(); + const uint32_t n_used = used; + + assert(n_used <= n_kv); + + defrag_info.moves.clear(); + + // determine which KV cells to move where + // + // cell i moves to ids[i] + // + // if ids[i] == i || ids[i] == n_kv, then cell i is not moved + // + std::vector ids(n_kv, n_kv); + + for (uint32_t i0 = 0; i0 < n_used; ++i0) { + const auto & cell0 = cells[i0]; + + if (!cell0.is_empty()) { + ids[i0] = i0; + + continue; + } + + // found a hole - fill it with data from the end of the cache + + uint32_t nh = 1; + + // determine the size of the hole + while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { + nh++; + } + + uint32_t nf = 0; + uint32_t is = n_kv - 1; + + // starting from the end, find nh non-empty cells + for (; is > i0; --is) { + const auto & cell1 = cells[is]; + + if (cell1.is_empty() || ids[is] != n_kv) { + continue; + } + + // non-empty cell which is not yet moved + nf++; + + if (nf == nh) { + break; + } + } + + // this can only happen if `n_used` is not accurate, which would be a bug + GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); + + nf = 0; + + uint32_t i1 = is; + + // are we moving a continuous block of memory? 
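+        // (descriptive note: `cont` tracks whether the current source cell extends the previous
+        //  move; runs of consecutive cells that land in consecutive destinations are coalesced
+        //  into a single {src, dst, len} entry in defrag_info.moves)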
+ bool cont = false; + + // go back and move the nf cells to the hole + for (; i1 < n_kv; ++i1) { + auto & cell1 = cells[i1]; + + if (cell1.is_empty() || ids[i1] != n_kv) { + cont = false; + continue; + } + + // this cell goes to (i0 + nf) + ids[i1] = i0 + nf; + + // move the cell meta data + cells[i0 + nf] = cell1; + + // clear the old cell and move the head there + cell1 = llama_kv_cell(); + head = n_used; + + if (!cont) { + defrag_info.moves.push_back({i1, i0 + nf, 1}); + cont = true; + } else { + defrag_info.moves.back().len++; + } + + nf++; + + if (nf == nh) { + break; + } + } + + //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); + + i0 += nh - 1; + } + + if (defrag_info.moves.size() == 0) { + return false; + } + + // LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves); + + return true; +} + +void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = size; + for (uint32_t i = 0; i < size; ++i) { + const auto & cell = cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = size; + } + } + } + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } + } +} + +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = this->v_trans ? 
1 : 0; + const uint32_t n_layer = hparams.n_layer; + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)k_l[il]->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(v_l[il]->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(v_l[il], src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + if (!find_slot(batch)) { + LLAMA_LOG_ERROR("%s: failed to find 
available cells in kv cache\n", __func__); + return false; + } + commit(); + + // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head + cell_count <= size); + GGML_ASSERT(cells[head].pos == batch.pos[0]); + GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); + GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(); + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_kv_cell & cell = cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + // TODO: llama_kv_cache_unified should have a notion of max sequences + //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + if (seq_id < 0) { + //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + return false; + } + + cell.seq_id.insert(seq_id); + + if (recurrent) { + int32_t & tail = cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + } + + head = 0; + used = cell_count; + } + + if (recurrent) { + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = head + i; + // make sure the recurrent states will keep their restored state + cells[cell_id].src = cell_id; + } + } + + return true; +} + +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + return false; + } + if (v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read 
and set the keys for the whole cell range + ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * size) * v_size_el; + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; } // // kv cache view // -struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max) { - struct llama_kv_cache_view result = { +llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) { + llama_kv_cache_view result = { /*.n_cells = */ 0, /*.n_seq_max = */ n_seq_max, /*.token_count = */ 0, - /*.used_cells = */ llama_get_kv_cache_used_cells(kv), + /*.used_cells = */ kv.get_used_cells(), /*.max_contiguous = */ 0, /*.max_contiguous_idx = */ -1, /*.cells = */ nullptr, @@ -651,7 +1280,7 @@ struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache return result; } -void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { +void llama_kv_cache_view_free(llama_kv_cache_view * view) { if (view->cells != nullptr) { 
free(view->cells); view->cells = nullptr; @@ -662,18 +1291,25 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { } } -void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) { - if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) { - view->n_cells = int32_t(kv.size); - void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); +void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) { + // TODO: rework this in the future, for now quick hack + const llama_kv_cache_unified * kvu = dynamic_cast(kv); + if (kvu == nullptr) { + LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__); + return; + } + + if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) { + view->n_cells = int32_t(kvu->size); + void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells); GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); - view->cells = (struct llama_kv_cache_view_cell *)p; + view->cells = (llama_kv_cache_view_cell *)p; p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells); GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences"); view->cells_sequences = (llama_seq_id *)p; } - const std::vector & kv_cells = kv.cells; + const std::vector & kv_cells = kvu->cells; llama_kv_cache_view_cell * c_curr = view->cells; llama_seq_id * cs_curr = view->cells_sequences; int32_t used_cells = 0; @@ -682,7 +1318,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct uint32_t max_contig = 0; int32_t max_contig_idx = -1; - for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) { + for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) { const size_t curr_size = kv_cells[i].seq_id.size(); token_count += curr_size; c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; @@ -720,8 +1356,8 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct view->max_contiguous_idx = max_contig_idx; view->token_count = token_count; view->used_cells = used_cells; - if (uint32_t(used_cells) != kv.used) { + if (uint32_t(used_cells) != kvu->used) { LLAMA_LOG_ERROR("%s: used cells mismatch. 
kv_cache says %d but we calculated %d\n", - __func__, kv.used, used_cells); + __func__, kvu->used, used_cells); } } diff --git a/llama/llama.cpp/src/llama-kv-cache.h b/llama/llama.cpp/src/llama-kv-cache.h index 1ed688e3b..25cbcb562 100644 --- a/llama/llama.cpp/src/llama-kv-cache.h +++ b/llama/llama.cpp/src/llama-kv-cache.h @@ -1,15 +1,58 @@ #pragma once #include "llama.h" +#include "llama-io.h" +#include "llama-memory.h" #include "ggml-cpp.h" +#include #include #include +struct llama_cparams; +struct llama_hparams; +struct llama_ubatch; + +struct llama_kv_cache : public llama_memory_i { + using llama_memory_i::llama_memory_i; + + virtual void restore() = 0; // call if batch processing fails - restores the cache state + virtual void commit() = 0; // call after successful batch processing - clears any pending state + + virtual int32_t get_n_tokens() const = 0; + virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache + + virtual bool get_can_shift() const = 0; + + bool get_can_edit() const override { return get_can_shift(); } +}; + +struct llama_kv_cache_guard { + llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} + + ~llama_kv_cache_guard() { + kv->restore(); + } + + void commit() { + kv->commit(); + } + +private: + llama_kv_cache * kv; +}; + +// block of KV slots to move when defragging +struct llama_kv_defrag_move { + uint32_t src; + uint32_t dst; + uint32_t len; +}; + struct llama_kv_cell { llama_pos pos = -1; - llama_pos delta = 0; + llama_pos delta = 0; int32_t src = -1; // used by recurrent state models to copy states int32_t tail = -1; @@ -29,10 +72,107 @@ struct llama_kv_cell { }; // ring-buffer of cached KV data -struct llama_kv_cache { +// TODO: pimpl +// TODO: add notion of max sequences +class llama_kv_cache_unified : public llama_kv_cache { +public: + // can be used to query data from the model if needed + struct callbacks { + std::function get_rope_factors; + }; + + llama_kv_cache_unified( + const llama_hparams & hparams, + callbacks cbs); + + virtual ~llama_kv_cache_unified() = default; + + // TODO: become constructor + bool init( + const llama_model & model, // TODO: do not reference the model + const llama_cparams & cparams, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + bool offload); + + int32_t get_n_tokens() const override; + int32_t get_used_cells() const override; + + size_t total_size() const; + + // TODO: better data structures to reduce the cost of this operation + llama_pos pos_max() const; + + void clear() override; + void defrag() override; + + virtual void restore() override; + virtual void commit() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + bool get_can_shift() const override; + + // find an empty slot of size "n_tokens" in the cache + // updates the cache head + // Note: On success, it's important that cache.head points + // to the first cell of the slot. 
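// ---------------------------------------------------------------------------
// Illustrative sketch only, not part of the patch: how the llama_kv_cache_guard
// introduced above is intended to bracket batch processing - restore() runs
// from the destructor on any early exit, and commit() is called once the batch
// has been applied successfully. process_ubatch() is a made-up stand-in for
// the real decode path.
#include "llama-kv-cache.h"

static bool process_ubatch(llama_kv_cache * kv, const llama_ubatch & ubatch); // hypothetical

static bool decode_with_guard(llama_kv_cache * kv, const llama_ubatch & ubatch) {
    llama_kv_cache_guard guard(kv);      // destructor will call kv->restore()

    if (!process_ubatch(kv, ubatch)) {
        // failure path: the guard rolls back any slots claimed via find_slot()
        return false;
    }

    // success path: clear the pending state so the destructor's restore()
    // has nothing left to undo
    guard.commit();
    return true;
}
// ---------------------------------------------------------------------------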
+ bool find_slot(const llama_ubatch & batch); + + // TODO: maybe not needed + uint32_t get_padding(const llama_cparams & cparams) const; + + // find how many cells are currently in use + uint32_t cell_max() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + + // defrag + + struct { + std::vector moves; + } defrag_info; + + // return true if cells have been moved + bool defrag_prepare(int32_t n_max_nodes); + + // commit/restore cache + + struct slot_range { + uint32_t c0 = 0; // note: these are cell indices, not sequence positions + uint32_t c1 = 0; + }; + + // pending cell updates that are not yet committed + struct { + std::vector ranges; + } pending; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1); + + // members + + const llama_hparams & hparams; + + callbacks cbs; + bool has_shift = false; bool do_defrag = false; + + // TODO: remove this and implement llama_kv_cache_recurrent instead bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token + bool v_trans = true; // the value tensor is transposed bool can_shift = false; @@ -46,173 +186,35 @@ struct llama_kv_cache { // computed before each graph build uint32_t n = 0; + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + +private: ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; - std::vector cells; - - std::vector k_l; // per layer - std::vector v_l; - - std::vector ctxs; + std::vector ctxs; std::vector bufs; - size_t total_size() const { - size_t size = 0; - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; - return size; - } - - // TODO: better data structures to reduce the cost of this operation - llama_pos max_pos() const { - llama_pos max_pos = -1; - for (const auto & cell : cells) { - max_pos = std::max(max_pos, cell.pos); - } - - return max_pos; - } + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; -// a structure holds information about the slot found in llama_kv_cache_find_slot -struct llama_kv_cache_slot_info { - std::pair boundaries; // slot boundaries [begin, end) - bool found = false; // the slot was found - - explicit llama_kv_cache_slot_info(bool found_) : found{found_} {} - llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {} - - operator bool() const { return found; } -}; - -// TODO: maybe not needed -uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams); - -bool llama_kv_cache_init( - struct llama_kv_cache & cache, - const llama_model & model, - const llama_cparams & cparams, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size, - bool offload); - -// find an empty slot of size "n_tokens" in the cache -// updates the cache head -// returns a structure holding information about the slot found -// Note: On success, it's important that cache.head points -// to the first cell of the slot. 
-struct llama_kv_cache_slot_info llama_kv_cache_find_slot( - struct llama_kv_cache & cache, - const struct llama_ubatch & batch); - -// find how many cells are currently in use -uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache); - -void llama_kv_cache_clear(struct llama_kv_cache & cache); - -bool llama_kv_cache_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1); - -void llama_kv_cache_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1); - -void llama_kv_cache_seq_keep( - struct llama_kv_cache & cache, - llama_seq_id seq_id); - -void llama_kv_cache_seq_add( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta); - -void llama_kv_cache_seq_div( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d); - -llama_pos llama_kv_cache_seq_pos_max( - struct llama_kv_cache & cache, - llama_seq_id seq_id); - -void llama_kv_cache_defrag(struct llama_kv_cache & cache); - -int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv); - -int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv); - -bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv); +// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified +//class llama_kv_cache_recurrent : public llama_kv_cache_unified { +//public: +// using llama_kv_cache_unified::llama_kv_cache_unified; +//}; // // kv cache view // -struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max); - -void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv); - -// -// kv cache restore -// - -// saves the kv_cache state for future recovery. -// used to rollback llama_kv_cache_find_slot changes. 
-struct llama_kv_slot_restorer { - struct llama_kv_cache_state { - uint32_t head = 0; - uint32_t n = 0; - } old_state; - - // for non-recurrent models only - // list of slots to restore - std::vector> slot_boundaries; - - bool do_restore = false; - - explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) { - old_state.head = cache.head; - old_state.n = cache.n; - } - - // saves a slot information for future restoration - void save(const struct llama_kv_cache_slot_info & slot) { - if (slot) { - do_restore = true; - if (slot.boundaries.first != slot.boundaries.second) { - slot_boundaries.push_back(slot.boundaries); - } - } - } - - // must be explicitly called to restore the kv_cache state - // and rollback changes from all llama_kv_cache_find_slot calls - void restore(struct llama_kv_cache & cache) { - if (do_restore) { - cache.head = old_state.head; - cache.n = old_state.n; - - if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased - llama_kv_cache_seq_rm(cache, -1, -1, -1); - } else { - for (auto & slot : slot_boundaries) { - llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second); - } - } - } - } -}; +llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max); +void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv); diff --git a/llama/llama.cpp/src/llama-memory.cpp b/llama/llama.cpp/src/llama-memory.cpp new file mode 100644 index 000000000..10173253e --- /dev/null +++ b/llama/llama.cpp/src/llama-memory.cpp @@ -0,0 +1 @@ +#include "llama-memory.h" diff --git a/llama/llama.cpp/src/llama-memory.h b/llama/llama.cpp/src/llama-memory.h new file mode 100644 index 000000000..dfa8c4e90 --- /dev/null +++ b/llama/llama.cpp/src/llama-memory.h @@ -0,0 +1,21 @@ +#pragma once + +#include "llama.h" + +// general concept of LLM memory +// the KV cache is a type of LLM memory, but there can be other types +class llama_memory_i { +public: + virtual void clear() = 0; + virtual void defrag() = 0; + + virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; + virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; + virtual void seq_keep(llama_seq_id seq_id) = 0; + virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0; + virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; + + virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; + + virtual bool get_can_edit() const = 0; +}; diff --git a/llama/llama.cpp/src/llama-mmap.cpp b/llama/llama.cpp/src/llama-mmap.cpp index b716630a8..9da97f1bc 100644 --- a/llama/llama.cpp/src/llama-mmap.cpp +++ b/llama/llama.cpp/src/llama-mmap.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #ifdef __has_include #if __has_include() @@ -34,6 +35,10 @@ #include #endif +#if defined(__APPLE__) +#include +#endif + // TODO: consider moving to llama-impl.h if needed in more places #if defined(_WIN32) static std::string llama_format_win_err(DWORD err) { @@ -471,7 +476,11 @@ struct llama_mlock::impl { char* errmsg = std::strerror(errno); bool suggest = (errno == ENOMEM); - +#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) + // visionOS/tvOS dont't support RLIMIT_MEMLOCK + // Skip resource limit checks on visionOS/tvOS + suggest = false; +#else struct rlimit lock_limit; if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { suggest = false; @@ -479,6 +488,7 @@ struct llama_mlock::impl { if (suggest && 
(lock_limit.rlim_max > lock_limit.rlim_cur + size)) { suggest = false; } +#endif LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : ""); diff --git a/llama/llama.cpp/src/llama-model-loader.cpp b/llama/llama.cpp/src/llama-model-loader.cpp index 45d087213..2e11507d9 100644 --- a/llama/llama.cpp/src/llama-model-loader.cpp +++ b/llama/llama.cpp/src/llama-model-loader.cpp @@ -448,7 +448,8 @@ llama_model_loader::llama_model_loader( std::vector & splits, bool use_mmap, bool check_tensors, - const struct llama_model_kv_override * param_overrides_p) { + const llama_model_kv_override * param_overrides_p, + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -460,6 +461,8 @@ llama_model_loader::llama_model_loader( } } + tensor_buft_overrides = param_tensor_buft_overrides_p; + // Load the main GGUF struct ggml_context * ctx = NULL; struct gguf_init_params params = { @@ -603,7 +606,9 @@ llama_model_loader::llama_model_loader( if (trace > 0) { const uint16_t sid = w.idx; - LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str()); + LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ] %8.2f MiB\n", __func__, + sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str(), + ggml_nbytes(tensor)/1024.0f/1024.0f); } } @@ -643,9 +648,9 @@ llama_model_loader::llama_model_loader( ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED); { - const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV - if (kid >= 0) { - ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid); + uint32_t ftype_val = 0; + if (get_key(LLM_KV_GENERAL_FILE_TYPE, ftype_val, false)) { + ftype = (llama_ftype) ftype_val; } } diff --git a/llama/llama.cpp/src/llama-model-loader.h b/llama/llama.cpp/src/llama-model-loader.h index fe35404b2..0f52b011b 100644 --- a/llama/llama.cpp/src/llama-model-loader.h +++ b/llama/llama.cpp/src/llama-model-loader.h @@ -77,8 +77,9 @@ struct llama_model_loader { llama_mmaps mappings; - std::map weights_map; - std::unordered_map kv_overrides; + std::map weights_map; + std::unordered_map kv_overrides; + const llama_model_tensor_buft_override * tensor_buft_overrides; gguf_context_ptr meta; std::vector contexts; @@ -95,7 +96,8 @@ struct llama_model_loader { std::vector & splits, // optional, only need if the split does not follow naming scheme bool use_mmap, bool check_tensors, - const struct llama_model_kv_override * param_overrides_p); + const llama_model_kv_override * param_overrides_p, + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p); template typename std::enable_if::value, bool>::type diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index db4f2685d..cd1d239c8 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -2,15 +2,22 @@ #include "llama-impl.h" #include "llama-mmap.h" +#include "llama-batch.h" +#include "llama-cparams.h" #include "llama-model-loader.h" +#include "llama-kv-cache.h" #include "ggml-cpp.h" #include #include +#include +#include #include +#include #include #include +#include #include #include @@ -26,6 +33,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_109M: return "109M"; case LLM_TYPE_137M: return "137M"; 
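// ---------------------------------------------------------------------------
// Illustrative sketch only, not part of the patch: what a null-terminated
// override list for the new tensor_buft_overrides loader parameter could look
// like. The loader code later in this diff matches each entry's pattern with
// std::regex_search against the full tensor name and uses the first hit. The
// struct and field names (pattern, buft) are taken from the diff; the choice
// of ggml_backend_cpu_buffer_type() is only an example.
#include "llama.h"
#include "ggml-backend.h"

static void example_set_overrides(llama_model_params & mparams) {
    // keep all FFN expert tensors in regular CPU memory (e.g. to save VRAM)
    static llama_model_tensor_buft_override overrides[2];

    overrides[0].pattern = "ffn_(up|down|gate)_exps";
    overrides[0].buft    = ggml_backend_cpu_buffer_type();
    overrides[1].pattern = nullptr;   // terminator checked by the loader

    mparams.tensor_buft_overrides = overrides;
}
// ---------------------------------------------------------------------------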
case LLM_TYPE_160M: return "160M"; + case LLM_TYPE_190M: return "190M"; case LLM_TYPE_220M: return "220M"; case LLM_TYPE_250M: return "250M"; case LLM_TYPE_270M: return "270M"; @@ -40,8 +48,10 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_1_4B: return "1.4B"; case LLM_TYPE_1_5B: return "1.5B"; case LLM_TYPE_1_6B: return "1.6B"; + case LLM_TYPE_1_8B: return "1.8B"; case LLM_TYPE_2B: return "2B"; case LLM_TYPE_2_8B: return "2.8B"; + case LLM_TYPE_2_9B: return "2.9B"; case LLM_TYPE_3B: return "3B"; case LLM_TYPE_4B: return "4B"; case LLM_TYPE_6B: return "6B"; @@ -79,6 +89,9 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B"; case LLM_TYPE_57B_A14B: return "57B.A14B"; case LLM_TYPE_27B: return "27B"; + case LLM_TYPE_290B: return "290B"; + case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; + case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; default: return "?B"; } } @@ -243,10 +256,11 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara return cur_buft; } } + return nullptr; } -// CPU: ACCEL -> CPU extra -> GPU host -> CPU +// CPU: ACCEL -> GPU host -> CPU extra -> CPU static buft_list_t make_cpu_buft_list(const std::vector & devices) { buft_list_t buft_list; @@ -262,19 +276,6 @@ static buft_list_t make_cpu_buft_list(const std::vector & de } } - // add extra buffer types - auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); - auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) - ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); - if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); - while (extra_bufts && *extra_bufts) { - buft_list.emplace_back(cpu_dev, *extra_bufts); - ++extra_bufts; - } - } - // add a host buffer type // storing the tensors in a host buffer is useful when the processing of large batches // is offloaded to a GPU device, since it reduces the time spent on data transfers @@ -289,6 +290,20 @@ static buft_list_t make_cpu_buft_list(const std::vector & de } } + // add extra buffer types, only if no GPU device is present + // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094 + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + // add the CPU buffer type for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { ggml_backend_dev_t dev = ggml_backend_dev_get(i); @@ -301,7 +316,7 @@ static buft_list_t make_cpu_buft_list(const std::vector & de } // GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU -static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) { +static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode split_mode, const float * tensor_split) { buft_list_t buft_list; // add the device split buffer type if requested and available @@ -366,9 +381,12 @@ struct llama_model::impl { 
layer_dev dev_input = {}; layer_dev dev_output = {}; std::vector dev_layer; + + bool has_tensor_overrides; }; -llama_model::llama_model(const struct llama_model_params & params) : params(params), pimpl(std::make_unique()) { +llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique()) { + pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; } llama_model::~llama_model() {} @@ -390,7 +408,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // get metadata as string for (int i = 0; i < gguf_get_n_kv(ctx); i++) { - enum gguf_type type = gguf_get_kv_type(ctx, i); + gguf_type type = gguf_get_kv_type(ctx, i); if (type == GGUF_TYPE_ARRAY) { continue; } @@ -472,6 +490,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; + // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); // non-transformer models do not have attention heads @@ -534,6 +556,25 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } } break; + case LLM_ARCH_LLAMA4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full + hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick + hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later + + switch (hparams.n_expert) { + case 16: type = LLM_TYPE_17B_16E; break; + case 128: type = LLM_TYPE_17B_128E; break; + default: type = LLM_TYPE_UNKNOWN; + } + + if (type == LLM_TYPE_17B_128E) { + hparams.use_kq_norm = false; + } + } break; case LLM_ARCH_MLLAMA: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -760,6 +801,22 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN3MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_PHI2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -787,9 +844,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_swa = 2047; } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) { // default value for Phi-3-mini-128k-instruct + // note: this seems incorrect because the window is bigger than the train context? hparams.n_swa = 262144; } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) { // default value for Phi-3-medium-128k-instruct + // note: this seems incorrect because the window is equal to the train context? 
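// ---------------------------------------------------------------------------
// Illustrative sketch only, not part of the patch: one plausible reading of the
// n_swa_pattern values set above ("pattern: 3 chunked - 1 full" for Llama 4,
// i.e. n_swa_pattern == 4): within each group of n_swa_pattern layers the last
// one uses full attention and the others use the sliding/chunked window. The
// helper name is invented; the actual mapping lives in llama_hparams::is_swa().
#include <cstdint>

static bool example_layer_uses_swa(uint32_t il, uint32_t n_swa_pattern) {
    if (n_swa_pattern <= 1) {
        return false;                             // no sliding-window layers
    }
    return (il % n_swa_pattern) < (n_swa_pattern - 1);
}
// e.g. with n_swa_pattern == 4, layers 0,1,2 are windowed and layer 3 is full,
// then the pattern repeats for layers 4..7, and so on.
// ---------------------------------------------------------------------------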
hparams.n_swa = 131072; } bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -865,11 +924,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_GEMMA2: { hparams.n_swa = 4096; // default value of gemma 2 + hparams.n_swa_pattern = 2; + hparams.attn_soft_cap = true; + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - hparams.attn_soft_cap = true; switch (hparams.n_layer) { case 26: type = LLM_TYPE_2B; break; @@ -880,6 +941,25 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GEMMA3: { + hparams.n_swa_pattern = 6; + + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_1B; break; + case 34: type = LLM_TYPE_4B; break; + case 48: type = LLM_TYPE_12B; break; + case 62: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); } break; case LLM_ARCH_STARCODER2: { @@ -945,6 +1025,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_COHERE2: { + hparams.n_swa_pattern = 4; + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -983,6 +1065,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 16: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; + case 64: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -1106,6 +1189,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_PLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_CHATGLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1127,6 +1219,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_9B; break; + case 61: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1227,6 +1328,36 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, 
hparams.n_lora_decay); + ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_190M; break; + case 24: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_450M; break; + case 2048: type = LLM_TYPE_1_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 28: + switch (hparams.n_embd) { + case 1536: type = LLM_TYPE_1_5B; break; + case 3584: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: { @@ -1277,6 +1408,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); } break; + case LLM_ARCH_BAILINGMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_16B; break; + case 88: type = LLM_TYPE_290B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_MISTRAL3: break; default: throw std::runtime_error("unsupported model architecture"); } @@ -1349,13 +1495,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); const int act_gpu_layers = devices.empty() ? 
0 : std::min(n_gpu_layers, (int)n_layer + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { + const bool is_swa = il < (int) hparams.n_layer && hparams.is_swa(il); if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { - LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(cpu_dev)); + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); return {cpu_dev, &pimpl->cpu_buft_list}; } const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); auto * dev = devices.at(layer_gpu); - LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(dev)); + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); return {dev, &pimpl->gpu_buft_list.at(dev)}; }; @@ -1459,7 +1606,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // skip unused tensors if (info.op == GGML_OP_NONE) { - LLAMA_LOG_WARN("model has unused tensor %s -- ignoring\n", tn.str().c_str()); + const size_t nbytes = ggml_nbytes(t_meta); + LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); + + ml.size_data -= nbytes; ml.n_created++; return nullptr; @@ -1501,9 +1651,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) { GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); } - ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list); + ggml_backend_buffer_type_t buft = nullptr; + + // check overrides + if (ml.tensor_buft_overrides) { + std::string tensor_name = tn.str(); + for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { + std::regex pattern(overrides->pattern); + if (std::regex_search(tensor_name, pattern)) { + LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft)); + buft = overrides->buft; + break; + } + } + } + if (!buft) { - throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + buft = select_weight_buft(hparams, t_meta, op, *buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + } } // avoid using a host buffer when using mmap @@ -1599,6 +1766,56 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_LLAMA4: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0"); + for (int i = 0; i < n_layer; ++i) { + bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0; + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = 
create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + if (is_moe_layer) { + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert + const int64_t n_ff_shexp = n_ff_exp; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; case LLM_ARCH_MLLAMA: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0); @@ -2023,7 +2240,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -2208,9 +2430,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -2239,6 +2461,77 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); } } break; + case LLM_ARCH_QWEN3: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, 
"weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_QWEN3MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } + } break; case LLM_ARCH_PHI2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -2281,13 +2574,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PHI3: { - const int64_t n_embd_head = n_embd / n_head; - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed if (output == NULL) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); @@ -2328,7 +2620,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); if (layer.wqkv == nullptr) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); @@ -2385,7 +2677,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -2411,7 +2708,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_CODESHELL: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if tok embd is NULL, init from output + if (tok_embd == NULL) { + tok_embd = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -2543,6 +2845,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_GEMMA3: { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, 
n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } } break; case LLM_ARCH_STARCODER2: { @@ -2737,6 +3070,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_OLMO2: { + const int64_t n_embd_head = n_embd / n_head; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -2751,7 +3086,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -3011,6 +3346,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_PLM: + { + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, 
kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_BITNET: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3176,16 +3540,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); if (layer.wqkv == nullptr) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); @@ -3197,6 +3561,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); } } break; + case LLM_ARCH_GLM4: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 
TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3240,7 +3643,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -3291,12 +3699,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); - layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 
TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, TENSOR_NOT_REQUIRED); GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL)); layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0); @@ -3326,7 +3734,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); const int time_mix_extra_dim = hparams.time_mix_extra_dim; @@ -3352,7 +3760,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); - layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, TENSOR_NOT_REQUIRED); layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); @@ -3361,47 +3769,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); // optional bias tensors - layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, TENSOR_NOT_REQUIRED); layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, 
"weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_CHAMELEON: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); @@ -3430,6 +3805,179 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.bskcn_tv = create_tensor(tn(LLM_TENSOR_BSKCN_TV, "weight", i), {2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? 
llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_RWKV7: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // Block 0, LN0 + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + if (i == 0) { + // actually not used + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); + } else { + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + } + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0); + + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); + + layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = 
create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); + + layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); + layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); + } + + } break; + case LLM_ARCH_ARWKV7: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + if (i == 0) { + // actually not used + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); + } else { + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + } + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED); + + try { + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); + } catch(std::runtime_error & e) { + // ARWKV models may not have gate tensors + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); + } + + layer.time_mix_k_k = 
create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + } break; + case LLM_ARCH_CHAMELEON: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); @@ -3538,7 +4086,46 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); } break; - case LLM_ARCH_MISTRAL3: break; + case LLM_ARCH_BAILINGMOE: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const 
int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0");
+                        }
+
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+
+                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
+                        layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                    }
+                } break;
         default:
             throw std::runtime_error("unknown architecture");
     }
@@ -3696,8 +4283,8 @@ size_t llama_model::size() const {
     return pimpl->n_bytes;
 }
 
-size_t llama_model::max_nodes() const {
-    return std::max(8192, tensors_by_name.size()*5);
+size_t llama_model::n_tensors() const {
+    return tensors_by_name.size();
 }
 
 size_t llama_model::n_devices() const {
@@ -3752,6 +4339,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
         LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
         LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
+        LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
         LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
         LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
         LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
@@ -3762,6 +4350,7 @@ void llama_model::print_info() const {
        LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
        LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
        LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
+        LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
        LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
        LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, 
hparams.n_expert); LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); @@ -3809,7 +4398,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((enum llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); } @@ -3818,12 +4407,24 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } + if (arch == LLM_ARCH_QWEN3MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } + if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); } + if (arch == LLM_ARCH_BAILINGMOE) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } + vocab.print_info(); } @@ -3885,9 +4486,13 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const { }); } -const struct ggml_tensor * llama_model::get_tensor(const char * name) const { +bool llama_model::has_tensor_overrides() const { + return pimpl->has_tensor_overrides; +} + +const ggml_tensor * llama_model::get_tensor(const char * name) const { auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(), - [name](const std::pair & it) { + [name](const std::pair & it) { return it.first == name; }); if (it == tensors_by_name.end()) { @@ -3897,13 +4502,9038 @@ const struct ggml_tensor * llama_model::get_tensor(const char * name) const { return it->second; } +struct llm_build_llama : public llm_graph_context { + llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // temperature tuning + ggml_tensor * inp_attn_scale = nullptr; + if (arch == LLM_ARCH_LLAMA4) { + inp_attn_scale = build_inp_attn_scale(); + } + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + bool use_rope = arch == LLM_ARCH_LLAMA4 + ? 
(il + 1) % hparams.n_no_rope_layer_step != 0 + : true; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (use_rope) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else if (inp_attn_scale) { + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) { + // Llama4TextL2Norm + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architecture + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + } else if (arch == LLM_ARCH_LLAMA4) { + // llama4 MoE + ggml_tensor * ffn_inp_normed = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(ffn_inp_normed, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, 
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + + // Shared experts + ggml_tensor * shexp_out = build_ffn(ffn_inp_normed, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shexp_out, "ffn_moe_shexp", il); + + cur = ggml_add(ctx0, moe_out, shexp_out); + cb(cur, "ffn_moe_out_merged", il); + + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + + // For Granite architecture + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architecture + if (hparams.f_logit_scale) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_mllama: public llm_graph_context { + llm_build_mllama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inpCAS; + + inpL = build_inp_embd(model.tok_embd); + inpCAS = build_inp_cross_attn_state(); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + const llama_kv_cache_unified * kv_self = static_cast(memory); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (hparams.cross_attention_layers(il)) { + if (!ubatch.embd && !cparams.cross_attn) { + continue; + } + + // cross attention layer + ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur); + cb(Qcur, "Qcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + + Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3)); + cb(Qcur, "Qcur", il); + + Qcur = build_norm(Qcur, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur, * Vcur; + if (ubatch.embd) { + Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS); + cb(Kcur, "Kcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404); + cb(Kcur, "Kcur", il); + + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + cb(Kcur, "Kcur", il); + + Kcur = build_norm(Kcur, 
model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self->k_l[il])); + + Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS); + cb(Vcur, "Vcur", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404); + cb(Vcur, "Vcur", il); + + Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); + cb(Vcur, "Vcur", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self->v_l[il])); + } else { + Kcur = ggml_view_tensor(ctx0, kv_self->k_l[il]); + cb(Kcur, "Kcur (view)", il); + + Vcur = ggml_view_tensor(ctx0, kv_self->v_l[il]); + cb(Vcur, "Vcur (view)", il); + } + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur); + cb(kq, "kq", il); + + // TODO: apply causal masks + struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); + cb(kq_soft_max, "kq_soft_max", il); + + Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur)); + cb(Vcur, "Vcur", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur); + cb(cur, "cur", il); + + // TODO: do this in place once? + cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate)); + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + // TODO: do this inplace once? 
+ cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } else { + // self attention layer + + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_deci : public llm_graph_context { + llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - 
contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_head = hparams.n_head(il); + + if (n_head == 0) { + // attention-free layer of Llama-3_1-Nemotron-51B + cur = inpL; + } else { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + } + + if (n_head > 0 && n_head_kv == 0) { + // "linear attention" of Llama-3_1-Nemotron-51B + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "wo", il); + } else if (n_head > 0) { + // self-attention + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architecture + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + // modified to support attention-free layer of Llama-3_1-Nemotron-51B + ggml_tensor * ffn_inp = cur; + if (n_head > 0) { + ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + } + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + // For Granite architecture + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + 
} + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architecture + if (hparams.f_logit_scale) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_baichuan : public llm_graph_context { + llm_build_baichuan(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr; + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + switch (model.type) { + case LLM_TYPE_7B: + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + break; + case LLM_TYPE_13B: + break; + default: + GGML_ABORT("fatal error"); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = 
build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_xverse : public llm_graph_context { + llm_build_xverse(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_falcon : public llm_graph_context { + llm_build_falcon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the 
positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * attn_norm; + + attn_norm = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(attn_norm, "attn_norm", il); + + // self-attention + { + if (model.layers[il].attn_norm_2) { + // Falcon-40B + cur = build_norm(inpL, + model.layers[il].attn_norm_2, + model.layers[il].attn_norm_2_b, + LLM_NORM, il); + cb(cur, "attn_norm_2", il); + } else { + cur = attn_norm; + } + + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // using mode = 2 for neox mode + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids); + } + + ggml_tensor * ffn_inp = cur; + + // feed forward + { + cur = build_ffn(attn_norm, // !! 
use the attn norm, not the result + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // norm + cur = build_norm(cur, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_grok : public llm_graph_context { + llm_build_grok(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // multiply by embedding_multiplier_scale of 78.38367176906169 + inpL = ggml_scale(ctx0, inpL, 78.38367176906169f); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Grok + // if attn_out_norm is present then apply it before adding the input + if (model.layers[il].attn_out_norm) { + cur = build_norm(cur, + model.layers[il].attn_out_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_out_norm", il); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + 
// MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_GELU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + + // Grok + // if layer_out_norm is present then apply it before adding the input + // Idea: maybe ffn_out_norm is a better name + if (model.layers[il].layer_out_norm) { + cur = build_norm(cur, + model.layers[il].layer_out_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "layer_out_norm", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // Grok + // multiply logits by output_multiplier_scale of 0.5773502691896257 + + cur = ggml_scale(ctx0, cur, 0.5773502691896257f); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_dbrx : public llm_graph_context { + llm_build_dbrx(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(cur, "wqkv_clamped", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il 
== n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].attn_out_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_out_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_starcoder : public llm_graph_context { + llm_build_starcoder(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + 
model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_refact : public llm_graph_context { + llm_build_refact(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_bert : public llm_graph_context { + llm_build_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * 
cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = nullptr; + + if (model.arch != LLM_ARCH_JINA_BERT_V2) { + inp_pos = build_inp_pos(); + } + + // construct input embeddings (token, type, position) + inpL = build_inp_embd(model.tok_embd); + + // token types are hardcoded to zero ("Sentence A") + ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); + inpL = ggml_add(ctx0, inpL, type_row0); + if (model.arch == LLM_ARCH_BERT) { + inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); + } + cb(inpL, "inp_embd", -1); + + // embed layer norm + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + cb(inpL, "inp_norm", -1); + + auto * inp_attn = build_attn_inp_no_cache(); + + // iterate layers + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * cur = inpL; + + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + + // self-attention + if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); + + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, il); + } + + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); + + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, il); + } + + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + } else { + // compute Q and K and RoPE them + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + + if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // re-add the layer input + cur = ggml_add(ctx0, cur, inpL); + + // attention layer norm + cur = build_norm(cur, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, il); + + if (model.layers[il].attn_norm_2 != nullptr) 
{ + cur = ggml_add(ctx0, cur, inpL); // re-add the layer input + cur = build_norm(cur, model.layers[il].attn_norm_2, model.layers[il].attn_norm_2_b, LLM_NORM, il); + } + + ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (model.arch == LLM_ARCH_BERT) { + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + } else { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + } + cb(cur, "ffn_out", il); + + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); + + // output layer norm + cur = build_norm(cur, model.layers[il].layer_out_norm, model.layers[il].layer_out_norm_b, LLM_NORM, il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_bloom : public llm_graph_context { + llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp_attn = build_attn_inp_kv_unified(); + + inpL = build_norm(inpL, + model.tok_norm, + model.tok_norm_b, + LLM_NORM, -1); + cb(inpL, "inp_norm", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // Add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + 
model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_mpt : public llm_graph_context { + llm_build_mpt(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * pos; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp_attn = build_attn_inp_kv_unified(); + + if (model.pos_embd) { + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + } + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * attn_norm; + + attn_norm = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(attn_norm, "attn_norm", il); + + // self-attention + { + cur = attn_norm; + + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + if (model.layers[il].bqkv){ + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + + if (hparams.f_clamp_kqv > 0.0f) { + cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(cur, "wqkv_clamped", il); + } + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Q/K Layernorm + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, il); + cb(Qcur, "Qcur", il); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // 
Add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // feed forward + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_act, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_stablelm : public llm_graph_context { + llm_build_stablelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + ggml_tensor * inpSA = cur; + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + NULL, + LLM_NORM, il); + cb(Qcur, "Qcur", il); + } + + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + NULL, + LLM_NORM, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = 
ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + if (model.layers[il].ffn_norm) { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + } else { + // parallel residual + cur = inpSA; + } + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen : public llm_graph_context { + llm_build_qwen(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // using mode = 2 for neox mode + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward forward + { + cur = build_norm(ffn_inp, + 
model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen2 : public llm_graph_context { + llm_build_qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; 
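// NOTE (editorial, not from the upstream llama.cpp source): each iteration of this loop is the
// standard pre-norm decoder block; schematically
//     h   = x + Attn(RMSNorm(x))        -> "ffn_inp"
//     out = h + SwiGLU(RMSNorm(h))      -> "l_out", which becomes inpL for the next layer
// with RoPE applied to Q/K inside the attention and any loaded control vector applied by
// build_cvec just before the hand-off to the next layer.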
+ } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen2vl : public llm_graph_context { + llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + 
ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen2moe : public llm_graph_context { + llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * cur_gate_inp = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(cur_gate_inp, "ffn_shexp_gate_inp", il); + + // sigmoid + ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp); + cb(cur_gate, "ffn_shexp_gate", il); + + ggml_tensor * cur_ffn = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_ffn, 
"ffn_shexp", il); + + ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate); + cb(ffn_shexp_out, "ffn_shexp_out", il); + + moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out); + cb(moe_out, "ffn_out", il); + + cur = moe_out; + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen3 : public llm_graph_context { + llm_build_qwen3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + 
inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen3moe : public llm_graph_context { + llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, 
"result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_phi2 : public llm_graph_context { + llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * attn_norm_output; + ggml_tensor * ffn_output; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + attn_norm_output = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(attn_norm_output, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv) { + cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // with phi2, we scale the Q to avoid precision issues + // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 + Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids); + } + + // FF + { + ffn_output = build_ffn(attn_norm_output, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(ffn_output, "ffn_out", il); + } + + cur = 
ggml_add(ctx0, cur, ffn_output); + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output_no_bias", -1); + + cur = ggml_add(ctx0, cur, model.output_b); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_phi3 : public llm_graph_context { + llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + auto * residual = inpL; + + // self-attention + { + // rope freq factors for 128k context + ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + ggml_tensor* attn_norm_output = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM_RMS, il); + cb(attn_norm_output, "attn_norm", il); + + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv) { + cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); + cb(Qcur, "Qcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor* inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } 
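// NOTE (editorial, not from the upstream llama.cpp source): as in llm_build_phi2 above, the
// attention scale is folded into Q here (Qcur is multiplied by 1.0f/sqrtf(n_embd_head)) and
// build_attn is then called with a scale of 1.0f. This is mathematically equivalent to the
// usual scaled dot product,
//     softmax((Q/sqrt(d)) * K^T) == softmax((Q * K^T) / sqrt(d)),
// but scaling Q first keeps the intermediate Q*K^T values smaller, which the phi2 comment
// attributes to avoiding precision issues.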
+ + cur = ggml_add(ctx0, cur, residual); + residual = cur; + + cur = build_norm(cur, + model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + + cur = ggml_add(ctx0, residual, cur); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + if (model.output_b != nullptr) { + cb(cur, "result_output_no_bias", -1); + cur = ggml_add(ctx0, cur, model.output_b); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_plamo : public llm_graph_context { + llm_build_plamo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * attention_norm = cur; + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + ggml_tensor * sa_out = cur; + + cur = attention_norm; + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); 
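// NOTE (editorial, not from the upstream llama.cpp source): PLaMo uses a parallel residual
// block: the feed-forward network below runs on the same normalized input as the attention
// (attention_norm), and both branch outputs are added back to the raw layer input, roughly
//     out = inpL + Attn(norm(inpL)) + FFN(norm(inpL)).
// That is why cur (the normalized input), sa_out (the attention output) and inpL are all
// filtered to the same surviving rows here before the additions that follow.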
+ } + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, sa_out); + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gpt2 : public llm_graph_context { + llm_build_gpt2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * pos; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + 
model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_codeshell : public llm_graph_context { + llm_build_codeshell(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_orion : public 
llm_graph_context { + llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + // if (model.layers[il].bq) { + // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + // cb(Qcur, "Qcur", il); + // } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + // if (model.layers[il].bk) { + // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + // cb(Kcur, "Kcur", il); + // } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + // if (model.layers[il].bv) { + // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + // cb(Vcur, "Vcur", il); + // } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_internlm2 : public llm_graph_context { + llm_build_internlm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = 
hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_minicpm3 : public llm_graph_context { + llm_build_minicpm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + //TODO: if the model varies, these parameters need to be read from the model + const int64_t n_embd_base = 256; + const float scale_embd = 12.0f; + const float scale_depth = 1.4f; + const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k)); + + const uint32_t n_embd_head_qk_rope 
= hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // scale the input embeddings + inpL = ggml_scale(ctx0, inpL, scale_embd); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * q = NULL; + // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = build_norm(q, + model.layers[il].attn_q_a_norm, NULL, + LLM_NORM_RMS, il); + cb(q, "q", il); + + // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + 0); + cb(q_nope, "q_nope", il); + + // and {n_head * n_embd_head_qk_rope, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} + ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_pe_compresseed, "kv_pe_compresseed", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, + kv_pe_compresseed->nb[1], + 0); + cb(kv_compressed, "kv_compressed", il); + + // and {n_embd_head_qk_rope, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + // TODO: the CUDA backend used to not support non-cont. 
(RMS) norm, investigate removing ggml_cont + kv_compressed = ggml_cont(ctx0, kv_compressed); + kv_compressed = build_norm(kv_compressed, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, il); + cb(kv_compressed, "kv_compressed", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. 
RoPE, investigate removing this + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + q_states, k_states, v_states, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // scale_res - scale the hidden states for residual connection + const float scale_res = scale_depth/sqrtf(float(n_layer)); + cur = ggml_scale(ctx0, cur, scale_res); + cb(cur, "hidden_scaled", il); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + // scale the hidden states for residual connection + cur = ggml_scale(ctx0, cur, scale_res); + cb(cur, "hidden_scaled_ffn", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head scaling + const float scale_lmhead = float(n_embd_base)/float(n_embd); + cur = ggml_scale(ctx0, cur, scale_lmhead); + cb(cur, "lmhead_scaling", -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gemma : public llm_graph_context { + llm_build_gemma(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + 
ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); + cb(Qcur, "Qcur_scaled", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gemma2 : public llm_graph_context { + llm_build_gemma2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e + switch (model.type) { + case LLM_TYPE_2B: + case LLM_TYPE_9B: + case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break; + default: GGML_ABORT("fatal error"); + }; + 
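// note: Qcur is pre-scaled above, so build_attn() below is called with a unit kq_scale of 1.0f +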
cb(Qcur, "Qcur_scaled", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f, il); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // final logit soft-capping + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gemma3 : public llm_graph_context { + llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + if (ubatch.token) { + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + } + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: is causal == true correct? might need some changes + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + const bool is_swa = hparams.is_swa(il); + + const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; + const float freq_scale_l = is_swa ? 
hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// TODO: move up next to build_starcoder +struct llm_build_starcoder2 : public llm_graph_context { + llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = 
build_norm(inpL, + model.layers[il].attn_norm, model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_mamba : public llm_graph_context { + const llama_model & model; + + llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * state_copy = build_inp_s_copy(); + ggml_tensor * state_mask = build_inp_s_mask(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il); + cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * 
inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // residual + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = build_norm(inpL, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + // TODO: split + ggml_tensor * build_mamba_layer( + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + const llama_ubatch & ubatch, + int il) const { + const llama_kv_cache_unified * kv_self = static_cast(memory); + + const auto kv_head = kv_self->head; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_seqs = ubatch.n_seqs; + // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) + const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; + // Use the same RMS norm as the final layer norm + const float norm_rms_eps = hparams.f_norm_rms_eps; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = kv_self->k_l[il]; + ggml_tensor * ssm_states_all = kv_self->v_l[il]; + + // (ab)using the KV cache to store the states + ggml_tensor * conv = build_copy_mask_state( + gf, conv_states_all, state_copy, state_mask, + hparams.n_embd_k_s(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); + ggml_tensor * ssm = build_copy_mask_state( + gf, ssm_states_all, state_copy, state_mask, + hparams.n_embd_v_s(), n_seqs); + ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} + ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur); + // split the above in two + // => {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); + ggml_tensor * z = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with 
mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + + // bias + x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); + + x = ggml_silu(ctx0, x); + } + + // ssm + { + // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} + ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x); + // split + ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); + ggml_tensor * B = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + ggml_tensor * C = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + + // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers + if (ssm_dt_b_c_rms) { + dt = ggml_rms_norm(ctx0, dt, norm_rms_eps); + B = ggml_rms_norm(ctx0, B, norm_rms_eps); + C = ggml_rms_norm(ctx0, C, norm_rms_eps); + } + + // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} + dt = build_lora_mm(model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + + // TODO: skip computing output earlier for unused tokens + + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + //cb(cur, "mamba_out", il); + + return cur; + } +}; + +struct llm_build_command_r : public llm_graph_context { + llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + const float f_logit_scale = hparams.f_logit_scale; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); + ggml_tensor * ffn_inp = cur; + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); 
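+ // note: the Q/K/V projections may carry biases; each bias is added only if the corresponding tensor is present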
+ if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + NULL, + LLM_NORM, il); + cb(Qcur, "Qcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + NULL, + LLM_NORM, il); + cb(Kcur, "Kcur", il); + } + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + + ggml_tensor * attn_out = cur; + + // feed-forward network + { + cur = build_ffn(ffn_inp, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + // add together residual + FFN + self-attention + cur = ggml_add(ctx0, cur, inpL); + cur = ggml_add(ctx0, cur, attn_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + if (f_logit_scale) { + cur = ggml_scale(ctx0, cur, f_logit_scale); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_cohere2 : public llm_graph_context { + llm_build_cohere2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + const float f_logit_scale = hparams.f_logit_scale; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + const bool is_swa = hparams.is_swa(il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il); + cb(cur, "attn_norm", il); + ggml_tensor * ffn_inp = cur; + + // 
self-attention + { + // rope freq factors for 128k context + ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (is_swa) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + + ggml_tensor * attn_out = cur; + + // feed-forward network + { + cur = build_ffn(ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, + NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, + il); + cb(cur, "ffn_out", il); + } + + // add together residual + FFN + self-attention + cur = ggml_add(ctx0, cur, inpL); + cur = ggml_add(ctx0, cur, attn_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + if (f_logit_scale) { + cur = ggml_scale(ctx0, cur, f_logit_scale); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// ref: https://allenai.org/olmo +// based on the original build_llama() function, changes: +// * non-parametric layer norm +// * clamp qkv +// * removed bias +// * removed MoE +struct llm_build_olmo : public llm_graph_context { + llm_build_olmo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // 
norm + cur = build_norm(inpL, + NULL, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + NULL, NULL, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + NULL, NULL, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_olmo2 : public llm_graph_context { + llm_build_olmo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = inpL; + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + 
cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_ffn(ffn_inp, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// based on the build_qwen2moe() function, changes: +// * removed shared experts +// * removed bias +// * added q, k norm +struct llm_build_olmoe : public llm_graph_context { + llm_build_olmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + 
cb(Vcur, "Vcur", il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_openelm : public llm_graph_context { + llm_build_openelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_head_qkv = 2*n_head_kv + n_head; + + cur = inpL; + ggml_tensor * residual = cur; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0)); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head)); + 
cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv))); + cb(Vcur, "Vcur", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, NULL, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur", il); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, NULL, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, NULL, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, NULL, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Qcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = inpL; + + // norm + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gptneox : public llm_graph_context { + llm_build_gptneox(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, 
n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // ffn + if (hparams.use_par_res) { + // attention and ffn are computed in parallel + // x = x + attn(ln1(x)) + ffn(ln2(x)) + + ggml_tensor * attn_out = cur; + + cur = build_norm(inpL, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, attn_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } else { + // attention and ffn are computed sequentially + // x = x + attn(ln1(x)) + // x = x + ffn(ln2(x)) + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_arctic : public llm_graph_context { + llm_build_arctic(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + 
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp); + cb(ffn_out, "ffn_out", il); + + // MoE + cur = build_norm(inpSA, + model.layers[il].ffn_norm_exps, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm_exps", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_deepseek : public llm_graph_context { + llm_build_deepseek(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", 
-1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_deepseek2 : public llm_graph_context { + llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + bool is_lite = (hparams.n_layer == 27); + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); + const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * q = NULL; + if (!is_lite) { + // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = build_norm(q, + model.layers[il].attn_q_a_norm, NULL, + LLM_NORM_RMS, il); + cb(q, "q", il); + + // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + } else { + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + } + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + 0); + cb(q_nope, "q_nope", il); + + // and {n_head * n_embd_head_qk_rope, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} + ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_pe_compresseed, "kv_pe_compresseed", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, + kv_pe_compresseed->nb[1], + 0); + cb(kv_compressed, "kv_compressed", il); + + // and {n_embd_head_qk_rope, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + // TODO: the CUDA backend used to not support non-cont. 
(RMS) norm, investigate removing ggml_cont + kv_compressed = ggml_cont(ctx0, kv_compressed); + kv_compressed = build_norm(kv_compressed, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, il); + cb(kv_compressed, "kv_compressed", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. 
RoPE, investigate removing this + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + q_states, k_states, v_states, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_bitnet : public llm_graph_context { + llm_build_bitnet(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].wq_scale) { + Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); + } + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, 
model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + // B1.K + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].wk_scale) { + Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); + } + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + // B1.V + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].wv_scale) { + Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); + } + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + NULL, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + + cur = build_norm(cur, + model.layers[il].attn_sub_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_sub_norm", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + if (model.layers[il].wo_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + } + if (model.layers[il].bo) { + cur = ggml_add(ctx0, cur, model.layers[il].bo); + } + cb(cur, "attn_o_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward forward + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, + NULL, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_sub_out", il); + + cur = build_norm(cur, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_sub_norm", il); + + cur = build_lora_mm(model.layers[il].ffn_down, cur); + if (model.layers[il].ffn_down_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + } + cb(cur, "ffn_down", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + // FIXME: do not use model.tok_embd directly, duplicate as model.output + cur = build_lora_mm(model.tok_embd, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_t5_enc : public llm_graph_context { + llm_build_t5_enc(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == 
hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc(); + + auto * inp_attn = build_attn_inp_no_cache(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; + ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo_enc, nullptr, + Qcur, Kcur, Vcur, kq_b, 1.0f, il); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up_enc, NULL, NULL, + model.layers[il].ffn_gate_enc, NULL, NULL, + model.layers[il].ffn_down_enc, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? 
LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cb(cur, "result_embd", -1); + + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_t5_dec : public llm_graph_context { + llm_build_t5_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * embd_enc = build_inp_cross_embd(); + ggml_tensor * pos_bucket_dec = build_inp_pos_bucket_dec(); + + const int64_t n_outputs_enc = embd_enc->ne[1]; + + auto * inp_attn_self = build_attn_inp_kv_unified(); + auto * inp_attn_cross = build_attn_inp_cross(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? 
model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; + ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); + + cur = build_attn(inp_attn_self, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, kq_b, 1.0f, il); + cb(cur, "kqv_out", il); + } + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "cross_inp", il); + + ggml_tensor * inpCA = cur; + + // norm + cur = build_norm(cur, + model.layers[il].attn_norm_cross, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm_cross", il); + + // cross-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc); + + cur = build_attn(inp_attn_cross, gf, + model.layers[il].wo_cross, nullptr, + Qcur, Kcur, Vcur, nullptr, 1.0f, il); + cb(cur, "kqv_out", il); + + //ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + //ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + + //ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + //cb(kq, "kq", il); + + //kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias); + //cb(kq, "kq_soft_max_ext", il); + + //ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc))); + //cb(v, "v", il); + + //ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq); + //cb(kqv, "kqv", il); + + //ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + //cb(kqv_merged, "kqv_merged", il); + + //cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + //cb(cur, "kqv_merged_cont", il); + + //ggml_build_forward_expand(gf, cur); + + //cur = build_lora_mm(model.layers[il].wo_cross, cur); + //cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? 
LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cb(cur, "result_embd", -1); + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_jais : public llm_graph_context { + llm_build_jais(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/float(n_embd_head), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_chatglm : public llm_graph_context { + llm_build_chatglm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = 
hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + } else { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = build_norm(inpL, + model.output_norm, + NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct 
llm_build_glm4 : public llm_graph_context { + llm_build_glm4(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + } else { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Post-attention norm (new!) 
+ cur = build_norm(cur, + model.layers[il].attn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Add the input (residual connection after post-attention norm) + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + // Pre-MLP norm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // MLP + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // Post-MLP norm + cur = build_norm(cur, + model.layers[il].ffn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_mlp_norm", il); + } + + // Add residual connection after post-MLP norm + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + // Final norm + cur = build_norm(inpL, + model.output_norm, + NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // Output projection + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_nemotron : public llm_graph_context { + llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + //GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = 
ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_exaone : public llm_graph_context { + llm_build_exaone(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = 
ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_rwkv6_base : public llm_graph_context { + const llama_model & model; + + llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) { + } + + ggml_tensor * build_rwkv6_channel_mix( + const llama_layer * layer, + ggml_tensor * cur, + ggml_tensor * x_prev, + llm_arch arch) const { + ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + switch (arch) { + case LLM_ARCH_RWKV6: + { + ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); + ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur); + + ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr)); + ggml_tensor * k = ggml_sqr( + ctx0, + ggml_relu( + ctx0, + build_lora_mm(layer->channel_mix_key, xk) + ) + ); + cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k)); + } break; + default: + GGML_ABORT("fatal error"); + } + + return cur; + } + + ggml_tensor * build_rwkv6_time_mix( + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * x_prev, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + const llama_ubatch & ubatch, + int il) const { + const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); + + const auto n_tokens = ubatch.n_tokens; + const auto n_seqs = ubatch.n_seqs; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_embd = hparams.n_embd; + const auto head_size = hparams.wkv_head_size; + const auto n_head = n_embd / head_size; + const auto n_head_kv = hparams.n_head_kv(il); + + const auto kv_head = kv_self->head; + + const auto & layer = model.layers[il]; + + bool is_qrwkv = layer.time_mix_first == nullptr; + + ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + + sx = ggml_reshape_2d(ctx0, sx, n_embd, n_tokens); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + + ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur); + + xxx = ggml_reshape_4d( + ctx0, + ggml_tanh( + ctx0, + ggml_mul_mat(ctx0, layer.time_mix_w1, xxx) + ), + layer.time_mix_w1->ne[1] / 5, 1, 5, n_tokens + ); + + xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2)); + + xxx = ggml_mul_mat( + ctx0, + ggml_reshape_4d( + ctx0, + layer.time_mix_w2, + layer.time_mix_w2->ne[0], layer.time_mix_w2->ne[1], 1, 5 + ), + xxx + ); + + ggml_tensor *xw, *xk, *xv, *xr, *xg; + if (layer.time_mix_lerp_fused) { + // fusing these weights makes some performance improvement + sx = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens); + cur = ggml_reshape_3d(ctx0, 
cur, n_embd, 1, n_tokens); + xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer.time_mix_lerp_fused), sx), cur); + xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); + xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); + xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float)); + xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float)); + xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float)); + } else { + // for backward compatibility + xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); + xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); + xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float)); + xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float)); + xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float)); + + xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer.time_mix_lerp_w), sx), cur); + xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer.time_mix_lerp_k), sx), cur); + xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer.time_mix_lerp_v), sx), cur); + xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer.time_mix_lerp_r), sx), cur); + xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer.time_mix_lerp_g), sx), cur); + } + + ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr); + ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk); + ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv); + if (layer.time_mix_receptance_b) { + r = ggml_add(ctx0, r, layer.time_mix_receptance_b); + } + if (layer.time_mix_key_b) { + k = ggml_add(ctx0, k, layer.time_mix_key_b); + } + if (layer.time_mix_value_b) { + v = ggml_add(ctx0, v, layer.time_mix_value_b); + } + + ggml_tensor * g = build_lora_mm(layer.time_mix_gate, xg); + if (is_qrwkv) { + g = ggml_sigmoid(ctx0, g); + } else { + g = ggml_silu(ctx0, g); + } + + if (n_head_kv != 0 && n_head_kv != n_head) { + GGML_ASSERT(n_head % n_head_kv == 0); + k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens); + v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens); + ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens); + k = ggml_repeat(ctx0, k, tmp); + v = ggml_repeat(ctx0, v, tmp); + } + + k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens); + v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens); + r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens); + + ggml_tensor * w = ggml_mul_mat( + ctx0, + layer.time_mix_decay_w2, + ggml_tanh( + ctx0, + ggml_mul_mat(ctx0, layer.time_mix_decay_w1, xw) + ) + ); + + w = ggml_add(ctx0, w, layer.time_mix_decay); + w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w))); + w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens); + + if (is_qrwkv) { + // k = k * (1 - w) + k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w)); + } + + ggml_tensor * wkv_state = build_copy_mask_state( + gf, kv_self->v_l[il], state_copy, state_mask, + hparams.n_embd_v_s(), n_seqs); + + ggml_tensor * wkv_output; + if (is_qrwkv) { + wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f)); + } else { + wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer.time_mix_first, w, wkv_state); + } + cur = 
ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0); + wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); + + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + wkv_state, + ggml_view_1d( + ctx0, + kv_self->v_l[il], + hparams.n_embd_v_s() * n_seqs, + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + ) + ) + ); + + if (!is_qrwkv) { + // group norm with head_count groups + cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens); + cur = ggml_norm(ctx0, cur, 64e-5f); + + // Convert back to regular vectors. + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b); + } else { + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + } + + cur = ggml_mul(ctx0, cur, g); + cur = build_lora_mm(layer.time_mix_output, cur); + + return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs); + } +}; + +struct llm_build_rwkv6 : public llm_build_rwkv6_base { + llm_build_rwkv6(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) { + GGML_ASSERT(hparams.token_shift_count == 2); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + + ggml_tensor * state_copy = build_inp_s_copy(); + ggml_tensor * state_mask = build_inp_s_mask(); + + const auto n_embd = hparams.n_embd; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_seqs = ubatch.n_seqs; + + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); + + ggml_tensor * token_shift = build_rwkv_token_shift_load( + gf, state_copy, state_mask, ubatch, il + ); + + ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); + ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); + + ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il); + cb(att_norm, "attn_norm", il); + + ggml_tensor * x_prev = ggml_concat( + ctx0, + att_shift, + ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0), + 1 + ); + + cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il); + cb(ffn_norm, "ffn_norm", il); + + x_prev = ggml_concat( + ctx0, + ffn_shift, + ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0), + 1 + ); + + token_shift = ggml_concat(ctx0, + ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)), + ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)), + 1 + ); + ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, 
n_tokens), inp_out_ids); + ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids); + x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids); + cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids); + } + + cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6); + cur = ggml_add(ctx0, cur, ffn_inp); + + if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) { + cur = ggml_scale(ctx0, cur, 0.5F); + } + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py +struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { + llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) { + GGML_ASSERT(n_embd == hparams.n_embd_k_s()); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * state_copy = build_inp_s_copy(); + ggml_tensor * state_mask = build_inp_s_mask(); + + const auto n_embd = hparams.n_embd; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_seqs = ubatch.n_seqs; + + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); + + ggml_tensor * token_shift = build_rwkv_token_shift_load( + gf, state_copy, state_mask, ubatch, il + ); + + ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); + cb(att_norm, "attn_norm", il); + + ggml_tensor * x_prev = ggml_concat( + ctx0, + token_shift, + ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0), + 1 + ); + + cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il); + + token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); + ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids); + } + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1); + + cb(cur, 
"result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_rwkv7_base : public llm_graph_context { + const llama_model & model; + + llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) { + } + + ggml_tensor * build_rwkv7_channel_mix( + const llama_layer * layer, + ggml_tensor * cur, + ggml_tensor * x_prev, + llm_arch arch) const { + ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + switch (arch) { + case LLM_ARCH_RWKV7: + { + ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); + + ggml_tensor * k = ggml_sqr( + ctx0, + ggml_relu( + ctx0, + build_lora_mm(layer->channel_mix_key, xk) + ) + ); + + cur = build_lora_mm(layer->channel_mix_value, k); + } break; + default: + GGML_ABORT("fatal error"); + } + + return cur; + } + + ggml_tensor * build_rwkv7_time_mix( + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * x_prev, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + ggml_tensor *& first_layer_value, + const llama_ubatch & ubatch, + int il) const { + const llama_kv_cache_unified * kv_self = static_cast(memory); + + const auto n_tokens = ubatch.n_tokens; + const auto n_seqs = ubatch.n_seqs; + const auto n_embd = hparams.n_embd; + const auto head_size = hparams.wkv_head_size; + const auto head_count = n_embd / head_size; + const auto n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = kv_self->head; + + const auto & layer = model.layers[il]; + + bool has_gating = layer.time_mix_g1 && layer.time_mix_g2; + + ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5); + sx = ggml_repeat(ctx0, sx, dummy); + + ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur); + + ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); + ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); + ggml_tensor * xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float)); + ggml_tensor * xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float)); + ggml_tensor * xa = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float)); + ggml_tensor * xg = has_gating ? ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) : nullptr; + + ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr); + ggml_tensor * w = ggml_add( + ctx0, + ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))), + layer.time_mix_w0 + ); + w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531)); + + ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk); + ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv); + if (first_layer_value == nullptr) { + first_layer_value = v; + } else { + // Add the first layer value as a residual connection. 
+ v = ggml_add(ctx0, v, + ggml_mul(ctx0, + ggml_sub(ctx0, first_layer_value, v), + ggml_sigmoid(ctx0, ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.time_mix_v2, ggml_mul_mat(ctx0, layer.time_mix_v1, xv)), + layer.time_mix_v0 + ) + ) + ) + ); + } + + ggml_tensor * g = nullptr; + if (layer.time_mix_g1 && layer.time_mix_g2) { + g = ggml_mul_mat(ctx0, layer.time_mix_g2, ggml_sigmoid(ctx0, ggml_mul_mat(ctx0, layer.time_mix_g1, xg))); + } + + ggml_tensor * a = ggml_sigmoid(ctx0, + ggml_add( + ctx0, + ggml_mul_mat(ctx0, layer.time_mix_a2, ggml_mul_mat(ctx0, layer.time_mix_a1, xa)), + layer.time_mix_a0 + ) + ); + + ggml_tensor * kk = ggml_reshape_3d(ctx0, ggml_mul(ctx0, k, layer.time_mix_k_k), head_size, head_count, n_tokens); + kk = ggml_l2_norm(ctx0, kk, 1e-12); + + ggml_tensor * ka = ggml_mul(ctx0, k, layer.time_mix_k_a); + k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka)); + + r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens); + w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens); + k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens); + v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); + a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); + + ggml_tensor * wkv_state = build_copy_mask_state( + gf, kv_self->v_l[il], state_copy, state_mask, + hparams.n_embd_v_s(), n_seqs); + + ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); + cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0); + wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); + + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + wkv_state, + ggml_view_1d( + ctx0, + kv_self->v_l[il], + hparams.n_embd_v_s() * n_seqs, + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + ) + ) + ); + + if (layer.time_mix_ln && layer.time_mix_ln_b) { + // group norm with head_count groups + cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens); + cur = ggml_norm(ctx0, cur, 64e-5f); + + // Convert back to regular vectors. 
+ cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b); + } else { + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + } + + ggml_tensor * rk = ggml_sum_rows(ctx0, + ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count))); + cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens)); + + if (has_gating) { + cur = ggml_mul(ctx0, cur, g); + } + cur = build_lora_mm(layer.time_mix_output, cur); + + return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs); + } +}; + +struct llm_build_rwkv7 : public llm_build_rwkv7_base { + llm_build_rwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) { + GGML_ASSERT(hparams.token_shift_count == 2); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * v_first = nullptr; + + inpL = build_inp_embd(model.tok_embd); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + + ggml_tensor * state_copy = build_inp_s_copy(); + ggml_tensor * state_mask = build_inp_s_mask(); + + const auto n_embd = hparams.n_embd; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_seqs = ubatch.n_seqs; + + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); + + ggml_tensor * token_shift = build_rwkv_token_shift_load( + gf, state_copy, state_mask, ubatch, il + ); + + ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); + ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); + + ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il); + cb(att_norm, "attn_norm", il); + + ggml_tensor * x_prev = ggml_concat( + ctx0, + att_shift, + ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0), + 1 + ); + + cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il); + cb(ffn_norm, "ffn_norm", il); + + x_prev = ggml_concat( + ctx0, + ffn_shift, + ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0), + 1 + ); + + token_shift = ggml_concat(ctx0, + ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)), + ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)), + 1 + ); + ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids); + ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids); + x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids); + } + + cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7); + 
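// Aside: a minimal scalar reference for the RWKV7 channel-mix built above, kept
// outside the vendored sources. It assumes one token, plain std::vector weights
// and the helper names below; only the math mirrors build_rwkv7_channel_mix():
// lerp between the current and token-shifted input, apply a squared-ReLU key
// projection, then the value projection (no receptance gate in v7, unlike v6).
#include <algorithm>
#include <vector>

static std::vector<float> matvec_ref(const std::vector<std::vector<float>> & W,
                                     const std::vector<float> & x) {
    std::vector<float> y(W.size(), 0.0f);
    for (size_t i = 0; i < W.size(); ++i) {
        for (size_t j = 0; j < x.size(); ++j) {
            y[i] += W[i][j] * x[j];
        }
    }
    return y;
}

static std::vector<float> rwkv7_channel_mix_ref(
        const std::vector<float> & cur,     // token t
        const std::vector<float> & x_prev,  // token t-1 (token shift)
        const std::vector<float> & lerp_k,  // channel_mix_lerp_k
        const std::vector<std::vector<float>> & w_key,
        const std::vector<std::vector<float>> & w_value) {
    std::vector<float> xk(cur.size());
    for (size_t i = 0; i < cur.size(); ++i) {
        xk[i] = cur[i] + (x_prev[i] - cur[i]) * lerp_k[i]; // xk = cur + sx * lerp_k
    }
    std::vector<float> k = matvec_ref(w_key, xk);
    for (float & v : k) {
        v = std::max(v, 0.0f); // ReLU
        v *= v;                // squared
    }
    return matvec_ref(w_value, k);
}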
cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + + +struct llm_build_arwkv7 : public llm_build_rwkv7_base { + llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) { + GGML_ASSERT(n_embd == hparams.n_embd_k_s()); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * v_first = nullptr; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * state_copy = build_inp_s_copy(); + ggml_tensor * state_mask = build_inp_s_mask(); + + const auto n_embd = hparams.n_embd; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_seqs = ubatch.n_seqs; + + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); + + ggml_tensor * token_shift = build_rwkv_token_shift_load( + gf, state_copy, state_mask, ubatch, il + ); + + ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); + cb(att_norm, "attn_norm", il); + + ggml_tensor * x_prev = ggml_concat( + ctx0, + token_shift, + ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0), + 1 + ); + + cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il); + + token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); + ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids); + } + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// ref: https://github.com/facebookresearch/chameleon +// based on the original build_llama() function, changes: +// * qk-norm +// * swin-norm +// * removed bias +// * removed MoE +struct llm_build_chameleon : public llm_graph_context { + llm_build_chameleon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = 
hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + if (hparams.swin_norm) { + cur = inpL; + } else { + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + } + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + if (model.layers[il].attn_q_norm) { + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur) * n_embd_head, + ggml_element_size(Qcur) * n_embd_head * n_head, + 0); + cb(Qcur, "Qcur", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, il); + cb(Qcur, "Qcur", il); + } + + if (model.layers[il].attn_k_norm) { + Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, + ggml_element_size(Kcur) * n_embd_head, + ggml_element_size(Kcur) * n_embd_head * n_head_kv, + 0); + cb(Kcur, "Kcur", il); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + + if (hparams.swin_norm) { + cur = build_norm(cur, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + } + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (!hparams.swin_norm) { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + } + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + if (hparams.swin_norm) { + cur = build_norm(cur, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; 
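// Aside: the chameleon layers above place RMS norm either before each block
// (default, llama-style pre-norm) or after it when hparams.swin_norm is set.
// A minimal sketch of the two orderings, kept outside the vendored sources;
// the float stand-ins and callable names are assumptions for illustration only.
#include <functional>

static float chameleon_layer_ref(float x, bool swin_norm,
                                 const std::function<float(float)> & rms_norm,
                                 const std::function<float(float)> & attn,
                                 const std::function<float(float)> & ffn) {
    if (!swin_norm) {
        // pre-norm: normalize the input of each block
        x = x + attn(rms_norm(x));
        x = x + ffn(rms_norm(x));
    } else {
        // swin-norm: normalize the output of each block instead
        x = x + rms_norm(attn(x));
        x = x + rms_norm(ffn(x));
    }
    return x;
}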
+ + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output_with_img_logits", -1); + + // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. + // Needs to be removed once image outputs are supported. + int img_token_end_idx = 8196; + int img_token_start_idx = 4; + int num_img_tokens = img_token_end_idx - img_token_start_idx; + // create a 1d tensor of size num_img_tokens filled with -FLT_MAX, + // which ensures that text token logits are always larger than image token logits + ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens); + img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX); + cb(img_logits, "img_logits", -1); + + cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_solar : public llm_graph_context { + llm_build_solar(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + struct ggml_tensor * bskcn_1; + struct ggml_tensor * bskcn_2; + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + if (hparams.n_bskcn(0, il)) { + bskcn_1 = inpSA; + } + + if (hparams.n_bskcn(1, il)) { + bskcn_2 = inpSA; + } + + if (hparams.n_bskcn(2, il)) { + inpSA = ggml_add( + ctx0, + ggml_mul(ctx0, bskcn_1, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)), + ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv)))); + } + + if (hparams.n_bskcn(3, il)) { + inpSA = ggml_add( + ctx0, + ggml_mul(ctx0, bskcn_2, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)), + ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv)))); + } + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
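// Aside: Solar's backbone skip connection (bskcn) above stores the layer input
// at layers flagged by n_bskcn(0/1, il) and mixes it back in at layers flagged
// by n_bskcn(2/3, il), using the two entries of bskcn_tv as mixing weights.
// A scalar sketch, kept outside the vendored sources; names are illustrative only.
static float solar_bskcn_blend_ref(float saved_hidden,    // bskcn_1 or bskcn_2
                                   float current_hidden,  // inpSA at this layer
                                   float tv0, float tv1)  // bskcn_tv[0], bskcn_tv[1]
{
    return saved_hidden * tv0 + current_hidden * tv1;
}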
+ cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_wavtokenizer_dec : public llm_graph_context { + llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL)); + + cur = ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1); + cur = ggml_add(ctx0, cur, model.conv1d_b); + + // posnet + for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) { + const auto & layer = model.layers[il].posnet; + + inpL = cur; + + switch (il) { + case 0: + case 1: + case 3: + case 4: + { + cur = build_norm(cur, + layer.norm1, + layer.norm1_b, + LLM_NORM_GROUP, 0); + + cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur); + + cur = ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1); + cur = ggml_add(ctx0, cur, layer.conv1_b); + + cur = build_norm(cur, + layer.norm2, + layer.norm2_b, + LLM_NORM_GROUP, 0); + + cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur); + + cur = ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1); + cur = ggml_add(ctx0, cur, layer.conv2_b); + + cur = ggml_add(ctx0, cur, inpL); + } break; + case 2: + { + cur = build_norm(cur, + layer.attn_norm, + layer.attn_norm_b, + LLM_NORM_GROUP, 0); + + ggml_tensor * q; + ggml_tensor * k; + ggml_tensor * v; + + q = ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1); + k = ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1); + v = ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1); + + q = ggml_add(ctx0, q, layer.attn_q_b); + 
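// Aside: the wavtokenizer decoder above leans on 1-D convolutions with
// half-size padding (ggml_conv_1d_ph, stride 1), so the time dimension keeps
// its length through the posnet and convnext stacks. A single-channel
// reference of that operation, kept outside the vendored sources; the
// signal/kernel names are illustrative only.
#include <vector>

static std::vector<float> conv1d_same_ref(const std::vector<float> & x,
                                          const std::vector<float> & kernel) {
    const int n    = (int) x.size();
    const int k    = (int) kernel.size();
    const int half = k / 2; // "same" padding: output length == input length
    std::vector<float> y(n, 0.0f);
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < k; ++j) {
            const int src = i + j - half;
            if (src >= 0 && src < n) { // zeros outside the signal
                y[i] += x[src] * kernel[j];
            }
        }
    }
    return y;
}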
k = ggml_add(ctx0, k, layer.attn_k_b); + v = ggml_add(ctx0, v, layer.attn_v_b); + + q = ggml_cont(ctx0, ggml_transpose(ctx0, q)); + k = ggml_cont(ctx0, ggml_transpose(ctx0, k)); + + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + + kq = ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f); + + cur = ggml_mul_mat(ctx0, kq, v); + + cur = ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1); + cur = ggml_add(ctx0, cur, layer.attn_o_b); + + cur = ggml_add(ctx0, cur, inpL); + } break; + case 5: + { + cur = build_norm(cur, + layer.norm, + layer.norm_b, + LLM_NORM_GROUP, 0); + } break; + default: GGML_ABORT("unknown posnet layer"); + }; + } + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + cur = build_norm(cur, + model.tok_norm, + model.tok_norm_b, + LLM_NORM, -1); + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + inpL = cur; + + // convnext + for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) { + const auto & layer = model.layers[il].convnext; + + cur = inpL; + + cur = ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1); + cur = ggml_add(ctx0, cur, layer.dw_b); + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + cur = build_norm(cur, + layer.norm, + layer.norm_b, + LLM_NORM, -1); + + cur = build_ffn(cur, + layer.pw1, layer.pw1_b, NULL, + NULL, NULL, NULL, + layer.pw2, layer.pw2_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + + cur = ggml_mul(ctx0, cur, layer.gamma); + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + inpL = ggml_add(ctx0, cur, inpL); + } + + cur = inpL; + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + cur = build_norm(cur, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + + cur = ggml_add(ctx0, cur, model.output_b); + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_plm : public llm_graph_context { + llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k)); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * q = NULL; + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + 0); + cb(q_nope, "q_nope", il); + + // and {n_head * n_embd_head_qk_rope, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, 
n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} + ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_pe_compresseed, "kv_pe_compresseed", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, + kv_pe_compresseed->nb[1], + 0); + cb(kv_compressed, "kv_compressed", il); + + // and {n_embd_head_qk_rope, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + kv_compressed = build_norm(kv_compressed, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, il); + cb(kv_compressed, "kv_compressed", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + q_states, k_states, v_states, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + 
inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_bailingmoe : public llm_graph_context { + llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + false, hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, 
LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +llama_memory_i * llama_model::create_memory() const { + llama_memory_i * res; + + switch (arch) { + case LLM_ARCH_MAMBA: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + { + res = new llama_kv_cache_unified(hparams, { + /*.get_rope_factors =*/ nullptr + }); + } break; + default: + { + res = new llama_kv_cache_unified(hparams, { + /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) { + // choose long/short freq factors based on the context size + if (layers[il].rope_freqs != nullptr) { + return layers[il].rope_freqs; + } + + if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { + return layers[il].rope_long; + } + + return layers[il].rope_short; + } + }); + } + } + + return res; +} + +llm_graph_result_ptr llama_model::build_graph( + const llm_graph_params & params, + ggml_cgraph * gf, + llm_graph_type type) const { + std::unique_ptr llm; + + switch (arch) { + case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: + case LLM_ARCH_MINICPM: + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_MLLAMA: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_DECI: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_BAICHUAN: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_FALCON: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GROK: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_STARCODER: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_REFACT: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_NOMIC_BERT: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_BLOOM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_MPT: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_STABLELM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN2VL: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN2MOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN3: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_QWEN3MOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_PHI2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_PHI3: + case LLM_ARCH_PHIMOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_PLAMO: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GPT2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_CODESHELL: + { 
+ llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_ORION: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_INTERNLM2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_MINICPM3: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GEMMA: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GEMMA2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GEMMA3: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_STARCODER2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_MAMBA: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_XVERSE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_COMMAND_R: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_COHERE2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_DBRX: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_OLMO: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_OLMO2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_OLMOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_OPENELM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GPTNEOX: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_ARCTIC: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_DEEPSEEK: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_DEEPSEEK2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_CHATGLM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_GLM4: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_BITNET: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_T5: + { + switch (type) { + case LLM_GRAPH_TYPE_ENCODER: + llm = std::make_unique(*this, params, gf); + break; + case LLM_GRAPH_TYPE_DEFAULT: + case LLM_GRAPH_TYPE_DECODER: + llm = std::make_unique(*this, params, gf); + break; + default: + GGML_ABORT("invalid graph type"); + }; + } break; + case LLM_ARCH_T5ENCODER: + { + llm = std::make_unique(*this, params, gf); + } + break; + case LLM_ARCH_JAIS: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_NEMOTRON: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_EXAONE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_RWKV6: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_RWKV6QWEN2: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_RWKV7: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_ARWKV7: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_CHAMELEON: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_SOLAR: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_WAVTOKENIZER_DEC: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_PLM: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_BAILINGMOE: + { + llm = std::make_unique(*this, params, gf); + } break; + default: + GGML_ABORT("fatal error"); + } + + // add on pooling layer + llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b); + + 
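// Aside: build_graph() above is an arch-keyed factory: every case constructs
// the matching graph builder (llm_build_rwkv7, llm_build_chameleon, ...)
// through std::make_unique and hands back its result. The idiom reduced to a
// self-contained sketch, kept outside the vendored sources, with placeholder
// types standing in for the real builder structs.
#include <memory>

struct builder_base { virtual ~builder_base() = default; };
struct builder_llama : builder_base {};
struct builder_rwkv7 : builder_base {};

enum class arch_id { llama, rwkv7 };

static std::unique_ptr<builder_base> make_builder(arch_id arch) {
    switch (arch) {
        case arch_id::llama: return std::make_unique<builder_llama>();
        case arch_id::rwkv7: return std::make_unique<builder_rwkv7>();
    }
    return nullptr; // unreachable for valid arch values
}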
return std::move(llm->res); +} + // // interface implementation // -struct llama_model_params llama_model_default_params() { - struct llama_model_params result = { +llama_model_params llama_model_default_params() { + llama_model_params result = { /*.devices =*/ nullptr, + /*.tensor_buft_overrides =*/ nullptr, /*.n_gpu_layers =*/ 0, /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, @@ -3925,59 +13555,59 @@ struct llama_model_params llama_model_default_params() { return result; } -const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model) { +const llama_vocab * llama_model_get_vocab(const llama_model * model) { return &model->vocab; } -void llama_free_model(struct llama_model * model) { +void llama_free_model(llama_model * model) { llama_model_free(model); } -void llama_model_free(struct llama_model * model) { +void llama_model_free(llama_model * model) { delete model; } -int32_t llama_model_n_ctx_train(const struct llama_model * model) { +int32_t llama_model_n_ctx_train(const llama_model * model) { return model->hparams.n_ctx_train; } -int32_t llama_model_n_embd(const struct llama_model * model) { +int32_t llama_model_n_embd(const llama_model * model) { return model->hparams.n_embd; } -int32_t llama_model_n_layer(const struct llama_model * model) { +int32_t llama_model_n_layer(const llama_model * model) { return model->hparams.n_layer; } -int32_t llama_model_n_head(const struct llama_model * model) { +int32_t llama_model_n_head(const llama_model * model) { return model->hparams.n_head(); } -int32_t llama_model_n_head_kv(const struct llama_model * model) { +int32_t llama_model_n_head_kv(const llama_model * model) { return model->hparams.n_head_kv(); } // deprecated -int32_t llama_n_ctx_train(const struct llama_model * model) { +int32_t llama_n_ctx_train(const llama_model * model) { return llama_model_n_ctx_train(model); } // deprecated -int32_t llama_n_embd(const struct llama_model * model) { +int32_t llama_n_embd(const llama_model * model) { return llama_model_n_embd(model); } // deprecated -int32_t llama_n_layer(const struct llama_model * model) { +int32_t llama_n_layer(const llama_model * model) { return llama_model_n_layer(model); } // deprecated -int32_t llama_n_head(const struct llama_model * model) { +int32_t llama_n_head(const llama_model * model) { return llama_model_n_head(model); } -enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { +llama_rope_type llama_model_rope_type(const llama_model * model) { switch (model->arch) { // these models do not use RoPE case LLM_ARCH_GPT2: @@ -3992,11 +13622,14 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_JAIS: case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: case LLM_ARCH_MLLAMA: case LLM_ARCH_DECI: case LLM_ARCH_BAICHUAN: @@ -4012,11 +13645,14 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: + case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_CHAMELEON: case LLM_ARCH_SOLAR: + case LLM_ARCH_BAILINGMOE: case LLM_ARCH_MISTRAL3: return LLAMA_ROPE_TYPE_NORM; @@ -4031,6 +13667,8 @@ enum llama_rope_type llama_model_rope_type(const 
struct llama_model * model) { case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: + case LLM_ARCH_QWEN3: + case LLM_ARCH_QWEN3MOE: case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_PHI2: @@ -4059,11 +13697,11 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { return LLAMA_ROPE_TYPE_NONE; } -float llama_model_rope_freq_scale_train(const struct llama_model * model) { +float llama_model_rope_freq_scale_train(const llama_model * model) { return model->hparams.rope_freq_scale_train; } -int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) { +int32_t llama_model_meta_val_str(const llama_model * model, const char * key, char * buf, size_t buf_size) { const auto & it = model->gguf_kv.find(key); if (it == model->gguf_kv.end()) { if (buf_size > 0) { @@ -4074,11 +13712,11 @@ int32_t llama_model_meta_val_str(const struct llama_model * model, const char * return snprintf(buf, buf_size, "%s", it->second.c_str()); } -int32_t llama_model_meta_count(const struct llama_model * model) { +int32_t llama_model_meta_count(const llama_model * model) { return (int)model->gguf_kv.size(); } -int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) { +int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) { if (i < 0 || i >= (int)model->gguf_kv.size()) { if (buf_size > 0) { buf[0] = '\0'; @@ -4090,7 +13728,7 @@ int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, c return snprintf(buf, buf_size, "%s", it->first.c_str()); } -int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) { +int32_t llama_model_meta_val_str_by_index(const llama_model * model, int32_t i, char * buf, size_t buf_size) { if (i < 0 || i >= (int)model->gguf_kv.size()) { if (buf_size > 0) { buf[0] = '\0'; @@ -4102,15 +13740,15 @@ int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int3 return snprintf(buf, buf_size, "%s", it->second.c_str()); } -int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { +int32_t llama_model_desc(const llama_model * model, char * buf, size_t buf_size) { return snprintf(buf, buf_size, "%s", model->desc().c_str()); } -uint64_t llama_model_size(const struct llama_model * model) { +uint64_t llama_model_size(const llama_model * model) { return model->size(); } -const char * llama_model_chat_template(const struct llama_model * model, const char * name) { +const char * llama_model_chat_template(const llama_model * model, const char * name) { const auto key = name ? 
LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); const auto & it = model->gguf_kv.find(key); @@ -4121,11 +13759,11 @@ const char * llama_model_chat_template(const struct llama_model * model, const c return it->second.c_str(); } -uint64_t llama_model_n_params(const struct llama_model * model) { +uint64_t llama_model_n_params(const llama_model * model) { return model->n_elements(); } -bool llama_model_has_encoder(const struct llama_model * model) { +bool llama_model_has_encoder(const llama_model * model) { switch (model->arch) { case LLM_ARCH_T5: return true; case LLM_ARCH_T5ENCODER: return true; @@ -4133,22 +13771,28 @@ bool llama_model_has_encoder(const struct llama_model * model) { } } -bool llama_model_has_decoder(const struct llama_model * model) { +bool llama_model_has_decoder(const llama_model * model) { switch (model->arch) { case LLM_ARCH_T5ENCODER: return false; default: return true; } } -llama_token llama_model_decoder_start_token(const struct llama_model * model) { +llama_token llama_model_decoder_start_token(const llama_model * model) { return model->hparams.dec_start_token_id; } -bool llama_model_is_recurrent(const struct llama_model * model) { +bool llama_model_is_recurrent(const llama_model * model) { switch (model->arch) { - case LLM_ARCH_MAMBA: return true; - case LLM_ARCH_RWKV6: return true; - case LLM_ARCH_RWKV6QWEN2: return true; - default: return false; + case LLM_ARCH_MAMBA: return true; + case LLM_ARCH_RWKV6: return true; + case LLM_ARCH_RWKV6QWEN2: return true; + case LLM_ARCH_RWKV7: return true; + case LLM_ARCH_ARWKV7: return true; + default: return false; } } + +const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { + return model->tensors_by_name; +} diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h index 7cf575872..21c4617bb 100644 --- a/llama/llama.cpp/src/llama-model.h +++ b/llama/llama.cpp/src/llama-model.h @@ -2,7 +2,9 @@ #include "llama.h" #include "llama-arch.h" +#include "llama-graph.h" #include "llama-hparams.h" +#include "llama-memory.h" #include "llama-vocab.h" #include @@ -11,6 +13,8 @@ #include #include +struct llama_cparams; +struct llama_ubatch; struct llama_model_loader; // available models @@ -26,6 +30,7 @@ enum llm_type { LLM_TYPE_109M, LLM_TYPE_137M, LLM_TYPE_160M, + LLM_TYPE_190M, LLM_TYPE_220M, LLM_TYPE_250M, LLM_TYPE_270M, @@ -40,8 +45,10 @@ enum llm_type { LLM_TYPE_1_4B, LLM_TYPE_1_5B, LLM_TYPE_1_6B, + LLM_TYPE_1_8B, LLM_TYPE_2B, LLM_TYPE_2_8B, + LLM_TYPE_2_9B, LLM_TYPE_3B, LLM_TYPE_4B, LLM_TYPE_6B, @@ -81,6 +88,9 @@ enum llm_type { LLM_TYPE_10B_128x3_66B, LLM_TYPE_57B_A14B, LLM_TYPE_27B, + LLM_TYPE_290B, + LLM_TYPE_17B_16E, // llama4 Scout + LLM_TYPE_17B_128E, // llama4 Maverick }; struct llama_layer_posnet { @@ -259,6 +269,20 @@ struct llama_layer { struct ggml_tensor * time_mix_receptance_b = nullptr; struct ggml_tensor * time_mix_gate = nullptr; + // rwkv7 + struct ggml_tensor * time_mix_w0 = nullptr; + struct ggml_tensor * time_mix_a0 = nullptr; + struct ggml_tensor * time_mix_a1 = nullptr; + struct ggml_tensor * time_mix_a2 = nullptr; + struct ggml_tensor * time_mix_v0 = nullptr; + struct ggml_tensor * time_mix_v1 = nullptr; + struct ggml_tensor * time_mix_v2 = nullptr; + struct ggml_tensor * time_mix_g1 = nullptr; + struct ggml_tensor * time_mix_g2 = nullptr; + struct ggml_tensor * time_mix_k_k = nullptr; + struct ggml_tensor * time_mix_k_a = nullptr; + struct ggml_tensor * time_mix_r_k = nullptr; + struct 
ggml_tensor * time_mix_ln = nullptr; struct ggml_tensor * time_mix_ln_b = nullptr; struct ggml_tensor * time_mix_output = nullptr; @@ -362,7 +386,7 @@ struct llama_model { std::string desc() const; size_t size() const; - size_t max_nodes() const; + size_t n_tensors() const; size_t n_devices() const; // total number of parameters in the model @@ -375,11 +399,26 @@ struct llama_model { ggml_backend_buffer_type_t select_buft(int il) const; + bool has_tensor_overrides() const; + const struct ggml_tensor * get_tensor(const char * name) const; + // TODO: move this to new llm_arch_model_i interface + llama_memory_i * create_memory() const; // TODO: params + + // TODO: move this to new llm_arch_model_i interface + llm_graph_result_ptr build_graph( + const llm_graph_params & params, + ggml_cgraph * gf, + llm_graph_type type) const; + private: struct impl; std::unique_ptr pimpl; }; const char * llm_type_name(llm_type type); + +// For internal test use +// TODO: remove +const std::vector> & llama_internal_get_tensor_map(const llama_model * model); diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index ebcbafa1c..8ae6dde87 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -47,8 +48,14 @@ struct quantize_state_impl { {} }; +// changes to this struct must be replicated in quantize.cpp +struct tensor_quantization { + std::string name; + ggml_type quant = GGML_TYPE_COUNT; +}; + static void llama_tensor_dequantize_impl( - struct ggml_tensor * tensor, std::vector> & output, std::vector & workers, + ggml_tensor * tensor, std::vector> & output, std::vector & workers, const size_t nelements, const int nthread ) { if (output.size() < nelements) { @@ -527,7 +534,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides); + llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); @@ -536,7 +543,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: model.load_hparams(ml); model.load_stats (ml); - struct quantize_state_impl qs(model, params); + quantize_state_impl qs(model, params); if (params->only_copy) { ftype = ml.ftype; @@ -663,7 +670,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // populate the original tensors so we get an initial meta data for (const auto * it : tensors) { uint16_t i_split = params->keep_split ? 
it->idx : 0; - struct ggml_tensor * tensor = it->tensor; + ggml_tensor * tensor = it->tensor; if (!ctx_outs[i_split]) { ctx_outs[i_split].reset(gguf_init_empty()); } @@ -712,7 +719,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: new_ofstream(0); for (const auto * it : tensors) { const auto & weight = *it; - struct ggml_tensor * tensor = weight.tensor; + ggml_tensor * tensor = weight.tensor; if (weight.idx != cur_split && params->keep_split) { close_ofstream(); new_ofstream(weight.idx); @@ -762,10 +769,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // NOTE: can't use LLM_TN here because the layer number is not known quantize &= name.find("ssm_conv1d.weight") == std::string::npos; - // do not quantize RWKV's time_mix_first tensors + // do not quantize RWKV's small yet 2D weights quantize &= name.find("time_mix_first.weight") == std::string::npos; + quantize &= name.find("time_mix_w0.weight") == std::string::npos; quantize &= name.find("time_mix_w1.weight") == std::string::npos; quantize &= name.find("time_mix_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_v0.weight") == std::string::npos; + quantize &= name.find("time_mix_v1.weight") == std::string::npos; + quantize &= name.find("time_mix_v2.weight") == std::string::npos; + quantize &= name.find("time_mix_a0.weight") == std::string::npos; + quantize &= name.find("time_mix_a1.weight") == std::string::npos; + quantize &= name.find("time_mix_a2.weight") == std::string::npos; + quantize &= name.find("time_mix_g1.weight") == std::string::npos; + quantize &= name.find("time_mix_g2.weight") == std::string::npos; quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; @@ -773,7 +789,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // do not quantize relative position bias (T5) quantize &= name.find("attn_rel_b.weight") == std::string::npos; - enum ggml_type new_type; + ggml_type new_type; void * new_data; size_t new_size; @@ -783,6 +799,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc. 
if (!params->pure && ggml_is_quantized(default_type)) { new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + // unless the user specifies a type + if (params->tensor_types) { + const std::vector & tensor_types = *static_cast *>(params->tensor_types); + for (const auto & [tname, qtype] : tensor_types) { + if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) { + if (qtype != new_type) { + LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype)); + } + new_type = qtype; + break; + } + } + } } if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; @@ -907,8 +936,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // interface implementation // -struct llama_model_quantize_params llama_model_quantize_default_params() { - struct llama_model_quantize_params result = { +llama_model_quantize_params llama_model_quantize_default_params() { + llama_model_quantize_params result = { /*.nthread =*/ 0, /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, /*.output_tensor_type =*/ GGML_TYPE_COUNT, @@ -920,6 +949,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { /*.keep_split =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, + /*.tensor_type =*/ nullptr, }; return result; diff --git a/llama/llama.cpp/src/llama-sampling.cpp b/llama/llama.cpp/src/llama-sampling.cpp index f40bf2db8..d14979850 100644 --- a/llama/llama.cpp/src/llama-sampling.cpp +++ b/llama/llama.cpp/src/llama-sampling.cpp @@ -1449,7 +1449,9 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, - size_t num_trigger_tokens); + size_t num_trigger_tokens, + const char ** trigger_patterns, + size_t num_trigger_patterns); static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_grammar *) smpl->ctx; @@ -1457,12 +1459,14 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { return; } - std::vector trigger_words; - for (auto & word : ctx->grammar->trigger_words) { - trigger_words.push_back(word.c_str()); + std::vector trigger_patterns_c; + trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size()); + for (auto & trigger_pattern : ctx->grammar->trigger_patterns) { + trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); } + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), - ctx->grammar->lazy, trigger_words.data(), trigger_words.size(), + ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); llama_grammar_free_impl(ctx->grammar); @@ -1472,7 +1476,8 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; - auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0); + auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0); + GGML_ASSERT(result); // copy the state { @@ -1516,16 +1521,38 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( const char ** trigger_words, size_t num_trigger_words, const 
llama_token * trigger_tokens, - size_t num_trigger_tokens) { + size_t num_trigger_tokens, + const char ** trigger_patterns, + size_t num_trigger_patterns) { auto * ctx = new llama_sampler_grammar; if (grammar_str != nullptr && grammar_str[0] != '\0') { + // TODO: remove trigger_words support. + if (trigger_words != nullptr && num_trigger_words > 0) { + GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0); + std::string trigger_pattern("[\\s\\S]*?("); + for (size_t i = 0; i < num_trigger_words; ++i) { + static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); + if (i > 0) { + trigger_pattern += "|"; + } + trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0"); + } + trigger_pattern += ")[\\s\\S]*"; + auto trigger_pattern_c = trigger_pattern.c_str(); + trigger_patterns = &trigger_pattern_c; + num_trigger_patterns = 1; + } *ctx = { /* .vocab = */ vocab, /* .grammar_str = */ grammar_str, /* .grammar_root = */ grammar_root, - /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens), + /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens), }; + if (!ctx->grammar) { + delete ctx; + return nullptr; + } } else { *ctx = { /* .vocab = */ vocab, @@ -1545,7 +1572,7 @@ struct llama_sampler * llama_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { - return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0); + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0); } struct llama_sampler * llama_sampler_init_grammar_lazy( @@ -1556,7 +1583,18 @@ struct llama_sampler * llama_sampler_init_grammar_lazy( size_t num_trigger_words, const llama_token * trigger_tokens, size_t num_trigger_tokens) { - return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens); + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0); +} + +struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns); } // penalties diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp index 7a185443a..c90f636ca 100644 --- a/llama/llama.cpp/src/llama-vocab.cpp +++ b/llama/llama.cpp/src/llama-vocab.cpp @@ -16,6 +16,7 @@ #include #include #include +#include // // helpers @@ -341,6 +342,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_MPT: case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: + case LLAMA_VOCAB_PRE_TYPE_TRILLION: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -393,10 +395,24 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; break; case LLAMA_VOCAB_PRE_TYPE_GPT4O: - 
// original regex from tokenizer.json - // [^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+ regex_exprs = { - "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + // original regex from tokenizer.json + // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_SUPERBPE: + regex_exprs = { + "\\p{N}+", + "(?=(\\d{3})+(?!\\d))", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE: + regex_exprs = { + // original regex from tokenizer.json + // "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+" + // FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?) 
+ "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", }; break; default: @@ -1547,6 +1563,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_PORO; clean_spaces = false; } else if ( + tokenizer_pre == "glm4" || tokenizer_pre == "chatglm-bpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; special_bos_id = LLAMA_TOKEN_NULL; @@ -1591,9 +1608,22 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "megrez") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; } else if ( - tokenizer_pre == "gpt-4o") { + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "superbpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; + clean_spaces = false; + } else if ( + tokenizer_pre == "trillion") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; + clean_spaces = false; + } else if ( + tokenizer_pre == "bailingmoe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; + clean_spaces = false; } else { LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -1772,6 +1802,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" || t.first == "<|endoftext|>" || t.first == "" + || t.first == "_" || t.first == "<|end▁of▁sentence|>" // DeepSeek ) { special_eot_id = t.second; @@ -1804,6 +1835,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" || t.first == "<|fim▁begin|>" // DeepSeek || t.first == "
"
+                        || t.first == "▁
"          // CodeLlama
                         ) {
                     special_fim_pre_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1821,6 +1853,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == ""
+                        || t.first == "▁"         // CodeLlama
                         ) {
                     special_fim_suf_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1838,6 +1871,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == ""
+                        || t.first == "▁"         // CodeLlama
                         ) {
                     special_fim_mid_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1922,6 +1956,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|endoftext|>"
                     || t.first == "<|eom_id|>"
                     || t.first == ""
+                    || t.first == "_"
                ) {
                 special_eog_ids.insert(t.second);
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2180,14 +2215,12 @@ void llama_vocab::impl::tokenizer_st_partition(std::forward_list<fragment_buffer_variant> & buffer, bool parse_special)
                     // check if match is within bounds of offset <-> length
-                    if (match + text.length() > raw_text_base_offset + raw_text_base_length) break;
-
 #ifdef PRETOKENIZERDEBUG
                     LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
 #endif
diff --git a/llama/llama.cpp/src/llama.cpp b/llama/llama.cpp/src/llama.cpp
index 01854fcef..d5164720b 100644
--- a/llama/llama.cpp/src/llama.cpp
+++ b/llama/llama.cpp/src/llama.cpp
@@ -2,9769 +2,28 @@
 
 #include "llama-chat.h"
 #include "llama-mmap.h"
-#include "llama-context.h"
 #include "llama-vocab.h"
-#include "llama-sampling.h"
-#include "llama-kv-cache.h"
 #include "llama-model-loader.h"
 #include "llama-model.h"
 
 #include "ggml.h"
-#include "ggml-alloc.h"
 #include "ggml-backend.h"
-#include "ggml-cpp.h"
 
 #include 
-#include 
-#include 
-#include 
-#include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
-static int llama_model_load(const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) {
-    // loading time will be recalculated after the first eval, so
-    // we take page faults deferred by mmap() into consideration
-    model.t_load_us = 0;
-    time_meas tm(model.t_load_us);
-
-    model.t_start_us = tm.t_start_us;
-
-    try {
-        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
-
-        ml.print_info();
-
-        model.hparams.vocab_only = params.vocab_only;
-
-        try {
-            model.load_arch(ml);
-        } catch(const std::exception & e) {
-            throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
-        }
-        try {
-            model.load_hparams(ml);
-        } catch(const std::exception & e) {
-            throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
-        }
-        try {
-            model.load_vocab(ml);
-        } catch(const std::exception & e) {
-            throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
-        }
-
-        model.load_stats(ml);
-        model.print_info();
-
-        if (params.vocab_only) {
-            LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
-            return 0;
-        }
-
-        if (!model.load_tensors(ml)) {
-            return -2;
-        }
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
-        return -1;
-    }
-
-    return 0;
-}
-
-//
-// llm_build
-//
-
-using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int il)>;
-
-enum llm_ffn_op_type {
-    LLM_FFN_SILU,
-    LLM_FFN_GELU,
-    LLM_FFN_RELU,
-    LLM_FFN_RELU_SQR,
-    LLM_FFN_SWIGLU,
-};
-
-enum llm_ffn_gate_type {
-    LLM_FFN_SEQ,
-    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
-};
-
-enum llm_norm_type {
-    LLM_NORM,
-    LLM_NORM_RMS,
-    LLM_NORM_GROUP,
-};
-
-static struct ggml_tensor * llm_build_inp_embd(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-        const llama_hparams & hparams,
-         const llama_ubatch & ubatch,
-         struct ggml_tensor * tok_embd,
-         const llm_build_cb & cb) {
-    const int64_t n_embd = hparams.n_embd;
-
-    struct ggml_tensor * inpL;
-
-    if (ubatch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
-        cb(lctx.inp_tokens, "inp_tokens", -1);
-        ggml_set_input(lctx.inp_tokens);
-
-        inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
-
-        // apply lora for embedding tokens if needed
-        for (auto & it : lctx.lora) {
-            struct llama_adapter_lora_weight * lw = it.first->get_weight(tok_embd);
-            if (lw == nullptr) {
-                continue;
-            }
-            const float adapter_scale = it.second;
-            const float scale = lw->get_scale(it.first->alpha, adapter_scale);
-            struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
-                ctx, lw->b, // non-transposed lora_b
-                ggml_get_rows(ctx, lw->a, lctx.inp_tokens)
-            ), scale);
-            inpL = ggml_add(ctx, inpL, inpL_delta);
-        }
-    } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
-        inpL = lctx.inp_embd;
-        ggml_set_input(lctx.inp_embd);
-    }
-
-    // For Granite architecture
-    if (hparams.f_embedding_scale != 0.0f) {
-        inpL = ggml_scale(ctx, inpL, hparams.f_embedding_scale);
-    }
-
-    cb(inpL, "inp_embd", -1);
-
-    return inpL;
-}
-
-static struct ggml_tensor * llm_build_inp_cross_attn_state(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-        const llama_hparams & hparams,
-         const llm_build_cb & cb) {
-    const int64_t n_embd = hparams.n_embd;
-
-    struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
-    cb(inpCAS, "inp_cross_attn_state", -1);
-    ggml_set_input(inpCAS);
-    lctx.inp_cross_attn_state = inpCAS;
-
-    return inpCAS;
-}
-
-static void llm_build_kv_store(
-        struct ggml_context * ctx,
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
-       const llama_kv_cache & kv,
-         struct ggml_cgraph * graph,
-         struct ggml_tensor * k_cur,
-         struct ggml_tensor * v_cur,
-                    int32_t   n_tokens,
-                    int32_t   kv_head,
-         const llm_build_cb & cb,
-                    int64_t   il) {
-    const int64_t n_ctx = cparams.n_ctx;
-
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    GGML_ASSERT(kv.size == n_ctx);
-
-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa)*kv_head);
-    cb(k_cache_view, "k_cache_view", il);
-
-    // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
-
-    assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
-
-    struct ggml_tensor * v_cache_view = nullptr;
-
-    if (cparams.flash_attn) {
-        v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)*kv_head);
-    } else {
-        // note: the V cache is transposed when not using flash attention
-        v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
-                (  n_ctx)*ggml_element_size(kv.v_l[il]),
-                (kv_head)*ggml_element_size(kv.v_l[il]));
-
-        v_cur = ggml_transpose(ctx, v_cur);
-    }
-    cb(v_cache_view, "v_cache_view", il);
-
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
-}
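
For reference, the removed llm_build_kv_store keeps K row-major per token but, without flash attention, stores V transposed so that each value channel occupies a contiguous run of n_ctx elements. A minimal index-math sketch of the two layouts (illustrative only, not part of the patch; names and sizes are placeholders):

#include <cstddef>

// offset of element (token t, channel c) in the per-layer K buffer:
// K is stored row-major per token, n_embd_k_gqa values per token
static size_t k_cache_index(size_t t, size_t c, size_t n_embd_k_gqa) {
    return t * n_embd_k_gqa + c;
}

// offset of element (token t, channel c) in the transposed V buffer:
// each channel is a contiguous row of length n_ctx, so the kq*V mat-mul
// reads sequential memory along the token dimension
static size_t v_cache_index_transposed(size_t t, size_t c, size_t n_ctx) {
    return c * n_ctx + t;
}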
-
-// do mat_mul, while optionally apply lora
-static struct ggml_tensor * llm_build_lora_mm(
-        struct llama_context & lctx,
-         struct ggml_context * ctx0,
-          struct ggml_tensor * w,
-          struct ggml_tensor * cur) {
-    struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
-    for (auto & it : lctx.lora) {
-        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
-        if (lw == nullptr) {
-            continue;
-        }
-        const float adapter_scale = it.second;
-        const float scale = lw->get_scale(it.first->alpha, adapter_scale);
-        struct ggml_tensor * ab_cur = ggml_mul_mat(
-            ctx0, lw->b,
-            ggml_mul_mat(ctx0, lw->a, cur)
-        );
-        ab_cur = ggml_scale(ctx0, ab_cur, scale);
-        res = ggml_add(ctx0, res, ab_cur);
-    }
-    return res;
-}
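
For reference, the removed llm_build_lora_mm computes W·x and then adds each adapter's low-rank delta B·(A·x) scaled by adapter_scale·alpha/rank. A standalone sketch of the same math on plain std::vector matrices (illustrative only, not part of the patch; the dense types here are stand-ins for ggml tensors):

#include <cstddef>
#include <vector>

// dense row-major mat-vec: out[r] = sum_c m[r][c] * x[c]
static std::vector<float> matvec(const std::vector<std::vector<float>> & m,
                                 const std::vector<float> & x) {
    std::vector<float> out(m.size(), 0.0f);
    for (size_t r = 0; r < m.size(); ++r) {
        for (size_t c = 0; c < x.size(); ++c) {
            out[r] += m[r][c] * x[c];
        }
    }
    return out;
}

// y = W x + (adapter_scale * alpha / rank) * B (A x)
static std::vector<float> lora_matvec(const std::vector<std::vector<float>> & W,  // n_out x n_in
                                      const std::vector<std::vector<float>> & A,  // rank  x n_in
                                      const std::vector<std::vector<float>> & B,  // n_out x rank
                                      const std::vector<float> & x,
                                      float alpha, float adapter_scale) {
    const float rank  = (float) A.size();
    const float scale = alpha != 0.0f ? adapter_scale * alpha / rank : adapter_scale;

    std::vector<float> y  = matvec(W, x);
    std::vector<float> ab = matvec(B, matvec(A, x));
    for (size_t i = 0; i < y.size(); ++i) {
        y[i] += scale * ab[i];
    }
    return y;
}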
-
-// do mat_mul_id, while optionally apply lora
-static struct ggml_tensor * llm_build_lora_mm_id(
-        struct llama_context & lctx,
-         struct ggml_context * ctx0,
-          struct ggml_tensor * w,   // struct ggml_tensor * as
-          struct ggml_tensor * cur, // struct ggml_tensor * b
-          struct ggml_tensor * ids) {
-    struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
-    for (auto & it : lctx.lora) {
-        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
-        if (lw == nullptr) {
-            continue;
-        }
-        const float alpha = it.first->alpha;
-        const float rank  = (float) lw->b->ne[0];
-        const float scale = alpha ? it.second * alpha / rank : it.second;
-        struct ggml_tensor * ab_cur = ggml_mul_mat_id(
-            ctx0, lw->b,
-            ggml_mul_mat_id(ctx0, lw->a, cur, ids),
-            ids
-        );
-        ab_cur = ggml_scale(ctx0, ab_cur, scale);
-        res = ggml_add(ctx0, res, ab_cur);
-    }
-    return res;
-}
-
-static struct ggml_tensor * llm_build_norm(
-        struct ggml_context * ctx,
-         struct ggml_tensor * cur,
-        const llama_hparams & hparams,
-         struct ggml_tensor * mw,
-         struct ggml_tensor * mb,
-              llm_norm_type   type,
-         const llm_build_cb & cb,
-                        int   il) {
-    switch (type) {
-        case LLM_NORM:       cur = ggml_norm      (ctx, cur, hparams.f_norm_eps);     break;
-        case LLM_NORM_RMS:   cur = ggml_rms_norm  (ctx, cur, hparams.f_norm_rms_eps); break;
-        case LLM_NORM_GROUP:
-            {
-                cur = ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]);
-                cur = ggml_group_norm(ctx, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
-                cur = ggml_reshape_2d(ctx, cur, cur->ne[0],    cur->ne[2]);
-            } break;
-    }
-
-    if (mw || mb) {
-        cb(cur, "norm", il);
-    }
-
-    if (mw) {
-        cur = ggml_mul(ctx, cur, mw);
-        if (mb) {
-            cb(cur, "norm_w", il);
-        }
-    }
-
-    if (mb) {
-        cur = ggml_add(ctx, cur, mb);
-    }
-
-    return cur;
-}
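
For reference, llm_build_norm dispatches between LayerNorm, RMSNorm and GroupNorm before applying the optional weight and bias. A plain-vector sketch of the LLM_NORM_RMS branch followed by the weight multiply (illustrative only, not part of the patch):

#include <cmath>
#include <vector>

// RMSNorm over one row: x / sqrt(mean(x^2) + eps), then the optional weight
static std::vector<float> rms_norm(const std::vector<float> & x,
                                   const std::vector<float> & weight,
                                   float eps /* f_norm_rms_eps */) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;
    const float inv_rms = 1.0f / std::sqrt(ss / x.size() + eps);

    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        out[i] = x[i] * inv_rms * weight[i];   // ggml_rms_norm then ggml_mul(mw)
    }
    return out;
}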
-
-static struct ggml_tensor * llm_build_ffn(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-         struct ggml_tensor * cur,
-         struct ggml_tensor * up,
-         struct ggml_tensor * up_b,
-         struct ggml_tensor * up_s,
-         struct ggml_tensor * gate,
-         struct ggml_tensor * gate_b,
-         struct ggml_tensor * gate_s,
-         struct ggml_tensor * down,
-         struct ggml_tensor * down_b,
-         struct ggml_tensor * down_s,
-         struct ggml_tensor * act_scales,
-            llm_ffn_op_type   type_op,
-          llm_ffn_gate_type   type_gate,
-         const llm_build_cb & cb,
-                        int   il) {
-    struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
-    cb(tmp, "ffn_up", il);
-
-    if (up_b) {
-        tmp = ggml_add(ctx, tmp, up_b);
-        cb(tmp, "ffn_up_b", il);
-    }
-
-    if (up_s) {
-        tmp = ggml_mul(ctx, tmp, up_s);
-        cb(tmp, "ffn_up_s", il);
-    }
-
-    if (gate) {
-        switch (type_gate) {
-            case LLM_FFN_SEQ:
-                {
-                    cur = llm_build_lora_mm(lctx, ctx, gate, tmp);
-                    cb(cur, "ffn_gate", il);
-                } break;
-            case LLM_FFN_PAR:
-                {
-                    cur = llm_build_lora_mm(lctx, ctx, gate, cur);
-                    cb(cur, "ffn_gate", il);
-                } break;
-        }
-
-        if (gate_b) {
-            cur = ggml_add(ctx, cur, gate_b);
-            cb(cur, "ffn_gate_b", il);
-        }
-
-        if (gate_s) {
-            cur = ggml_mul(ctx, cur, gate_s);
-            cb(cur, "ffn_gate_s", il);
-        }
-
-    } else {
-        cur = tmp;
-    }
-
-    switch (type_op) {
-        case LLM_FFN_SILU:
-            {
-                cur = ggml_silu(ctx, cur);
-                cb(cur, "ffn_silu", il);
-            } break;
-        case LLM_FFN_GELU:
-            {
-                cur = ggml_gelu(ctx, cur);
-                cb(cur, "ffn_gelu", il);
-                if (act_scales != NULL) {
-                    cur = ggml_div(ctx, cur, act_scales);
-                    cb(cur, "ffn_act", il);
-                }
-            } break;
-        case LLM_FFN_RELU:
-            {
-                cur = ggml_relu(ctx, cur);
-                cb(cur, "ffn_relu", il);
-            } break;
-        case LLM_FFN_RELU_SQR:
-            {
-                cur = ggml_relu(ctx, cur);
-                cb(cur, "ffn_relu", il);
-
-                cur = ggml_sqr(ctx, cur);
-                cb(cur, "ffn_sqr(relu)", il);
-            } break;
-        case LLM_FFN_SWIGLU:
-            {
-                // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
-                int64_t split_point = cur->ne[0] / 2;
-                struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-                x0 = ggml_silu(ctx, x0);
-                cb(cur, "ffn_silu", il);
-
-                cur = ggml_mul(ctx, x0, x1);
-                cb(cur, "ffn_mul", il);
-            } break;
-    }
-
-    if (type_gate == LLM_FFN_PAR) {
-        cur = ggml_mul(ctx, cur, tmp);
-        cb(cur, "ffn_gate_par", il);
-    }
-
-    if (down) {
-        cur = llm_build_lora_mm(lctx, ctx, down, cur);
-    }
-
-    if (down_b) {
-        cb(cur, "ffn_down", il);
-    }
-
-    if (down_b) {
-        cur = ggml_add(ctx, cur, down_b);
-    }
-
-    if (down_s) {
-        cur = ggml_mul(ctx, cur, down_s);
-        cb(cur, "ffn_down_s", il);
-    }
-
-    return cur;
-}
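
For reference, the LLM_FFN_SWIGLU branch above splits the up-projection in half, applies SiLU to the first half and multiplies it element-wise with the second, halving the width as described in the linked paper. A standalone sketch on a single row (illustrative only, not part of the patch):

#include <cmath>
#include <vector>

// SwiGLU on one row: the first half is gated by SiLU and multiplied
// element-wise with the second half, so the output has half the width
static std::vector<float> swiglu(const std::vector<float> & row) {
    const size_t split_point = row.size() / 2;
    std::vector<float> out(split_point);
    for (size_t i = 0; i < split_point; ++i) {
        const float x0 = row[i];               // gated half
        const float x1 = row[split_point + i]; // linear half
        const float silu = x0 / (1.0f + std::exp(-x0));
        out[i] = silu * x1;
    }
    return out;
}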
-
-static struct ggml_tensor * llm_build_moe_ffn(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-         struct ggml_tensor * cur,
-         struct ggml_tensor * gate_inp,
-         struct ggml_tensor * up_exps,
-         struct ggml_tensor * gate_exps,
-         struct ggml_tensor * down_exps,
-         struct ggml_tensor * exp_probs_b,
-                    int64_t   n_expert,
-                    int64_t   n_expert_used,
-            llm_ffn_op_type   type_op,
-                       bool   norm_w,
-                       bool   scale_w,
-                      float   w_scale,
-llama_expert_gating_func_type gating_op,
-         const llm_build_cb & cb,
-                        int   il) {
-    int64_t n_embd = cur->ne[0];
-    int64_t n_tokens = cur->ne[1];
-
-    ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
-    cb(logits, "ffn_moe_logits", il);
-
-    ggml_tensor * probs = nullptr;
-    switch (gating_op) {
-        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
-            {
-                probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
-            } break;
-        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
-            {
-                probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
-            } break;
-        default:
-            GGML_ABORT("fatal error");
-    }
-    cb(probs, "ffn_moe_probs", il);
-
-    // add experts selection bias - introduced in DeepSeek V3
-    // leave probs unbiased as it's later used to get expert weights
-    ggml_tensor * selection_probs = probs;
-    if (exp_probs_b != nullptr) {
-        selection_probs = ggml_add(ctx, probs, exp_probs_b);
-        cb(selection_probs, "ffn_moe_probs_biased", il);
-    }
-
-    // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
-    cb(selected_experts->src[0], "ffn_moe_argsort", il);
-    cb(selected_experts, "ffn_moe_topk", il);
-
-    ggml_tensor * weights = ggml_get_rows(ctx,
-            ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-    cb(weights, "ffn_moe_weights", il);
-
-    if (norm_w) {
-        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
-
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
-        cb(weights_sum, "ffn_moe_weights_sum", il);
-
-        weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights_norm", il);
-
-        weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
-    }
-    if (scale_w) {
-        weights = ggml_scale(ctx, weights, w_scale);
-        cb(weights, "ffn_moe_weights_scaled", il);
-    }
-
-    cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
-    ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
-
-    ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
-
-    switch (type_op) {
-        case LLM_FFN_SILU:
-            {
-                gate = ggml_silu(ctx, gate);
-                cb(gate, "ffn_moe_silu", il);
-            } break;
-        case LLM_FFN_GELU:
-            {
-                gate = ggml_gelu(ctx, gate);
-                cb(gate, "ffn_moe_gelu", il);
-            } break;
-        default:
-            GGML_ABORT("fatal error");
-    }
-
-    ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
-
-    ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
-    cb(experts, "ffn_moe_down", il);
-
-    experts = ggml_mul(ctx, experts, weights);
-
-    // aggregate experts
-    ggml_tensor * moe_out = nullptr;
-    for (int i = 0; i < n_expert_used; ++i) {
-        ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
-                experts->nb[2], i*experts->nb[1]);
-
-        if (i == 0) {
-            moe_out = cur_expert;
-        } else {
-            moe_out = ggml_add(ctx, moe_out, cur_expert);
-        }
-    }
-
-    if (n_expert_used == 1) {
-        // avoid returning a non-contiguous tensor
-        moe_out = ggml_cont(ctx, moe_out);
-    }
-
-    return moe_out;
-}
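
For reference, the removed llm_build_moe_ffn routes each token by turning the gate logits into probabilities (softmax or sigmoid), picking the top n_expert_used experts, and optionally renormalizing the selected weights before mixing the expert outputs. A per-token routing sketch (illustrative only, not part of the patch; softmax gating assumed):

#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>
#include <vector>

// returns the indices and (optionally renormalized) weights of the top
// n_expert_used experts for one token, given raw gating logits
static std::pair<std::vector<int>, std::vector<float>>
route_token(const std::vector<float> & logits, int n_expert_used, bool norm_w) {
    // softmax over the experts
    const float mx = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t e = 0; e < logits.size(); ++e) {
        probs[e] = std::exp(logits[e] - mx);
        sum += probs[e];
    }
    for (float & p : probs) p /= sum;

    // top-k selection by probability
    std::vector<int> idx(logits.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });
    idx.resize(n_expert_used);

    std::vector<float> weights;
    for (int e : idx) weights.push_back(probs[e]);

    // optionally renormalize so the selected weights sum to 1
    if (norm_w) {
        const float wsum = std::accumulate(weights.begin(), weights.end(), 0.0f);
        for (float & w : weights) w /= wsum;
    }
    return {idx, weights};
}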
-
-static struct ggml_tensor * llm_build_kqv(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-       const llama_kv_cache & kv,
-         struct ggml_cgraph * graph,
-         struct ggml_tensor * wo,
-         struct ggml_tensor * wo_b,
-         struct ggml_tensor * q_cur,
-         struct ggml_tensor * kq_mask,
-                    int32_t   n_tokens,
-                    int32_t   n_kv,
-                    float     kq_scale,
-         const llm_build_cb & cb,
-                    int       il) {
-    const llama_model   & model   = lctx.model;
-    const llama_hparams & hparams = lctx.model.hparams;
-    const llama_cparams & cparams = lctx.cparams;
-
-    const int64_t n_ctx         = cparams.n_ctx;
-    const int64_t n_head        = hparams.n_head(il);
-    const int64_t n_head_kv     = hparams.n_head_kv(il);
-    const int64_t n_embd_head_k = hparams.n_embd_head_k;
-    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_head_v = hparams.n_embd_head_v;
-    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(il);
-
-    struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
-    cb(q, "q", il);
-
-    struct ggml_tensor * k =
-        ggml_view_3d(ctx, kv.k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
-                0);
-    cb(k, "k", il);
-
-    struct ggml_tensor * cur;
-
-    if (cparams.flash_attn) {
-        GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);
-
-        // split cached v into n_head heads (not transposed)
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx, kv.v_l[il],
-                    n_embd_head_v, n_kv, n_head_kv,
-                    ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
-                    ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
-                    0);
-        cb(v, "v", il);
-
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
-                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-
-        cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
-    } else {
-        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
-        cb(kq, "kq", il);
-
-        // note: this op tends to require high floating point range
-        //       while for some models F16 is enough, for others it is not, so we default to F32 here
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
-        if (model.arch == LLM_ARCH_GROK) {
-            // need to do the following:
-            // multiply by attn_output_multiplyer of 0.08838834764831845
-            // and then :
-            // kq = 30 * tanh(kq / 30)
-            // before the softmax below
-
-            kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
-            kq = ggml_scale(ctx, kq, 30);
-        }
-
-        if (hparams.attn_soft_cap) {
-            kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
-            kq = ggml_tanh(ctx, kq);
-            kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
-        }
-
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        cb(kq, "kq_soft_max_ext", il);
-
-        GGML_ASSERT(kv.size == n_ctx);
-
-        // split cached v into n_head heads
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx, kv.v_l[il],
-                    n_kv, n_embd_head_v, n_head_kv,
-                    ggml_element_size(kv.v_l[il])*n_ctx,
-                    ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
-                    0);
-        cb(v, "v", il);
-
-        struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
-        cb(kqv, "kqv", il);
-
-        struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
-        cb(kqv_merged, "kqv_merged", il);
-
-        cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        cb(cur, "kqv_merged_cont", il);
-    }
-
-    ggml_build_forward_expand(graph, cur);
-
-    if (wo) {
-        cur = llm_build_lora_mm(lctx, ctx, wo, cur);
-    }
-
-    if (wo_b) {
-        cb(cur, "kqv_wo", il);
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx, cur, wo_b);
-    }
-
-    return cur;
-}
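
For reference, llm_build_kqv optionally squashes the attention logits before the softmax: Grok uses a fixed 30·tanh(kq·0.0884/30), and models with attn_soft_cap use c·tanh(kq/c) with c = f_attn_logit_softcapping. A scalar sketch of both transforms (illustrative only, not part of the patch):

#include <cmath>

// generic soft-cap used when hparams.attn_soft_cap is set:
// scale down, tanh, scale back up, keeping logits in (-cap, cap)
static float soft_cap(float kq, float cap /* f_attn_logit_softcapping */) {
    return cap * std::tanh(kq / cap);
}

// Grok variant from the removed code: the attn_output_multiplier (~0.0884)
// is folded into the scale, then the result is capped at 30
static float grok_cap(float kq) {
    return 30.0f * std::tanh(kq * (0.08838834764831845f / 30.0f));
}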
-
-static struct ggml_tensor * llm_build_kv(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-       const llama_kv_cache & kv,
-         struct ggml_cgraph * graph,
-         struct ggml_tensor * wo,
-         struct ggml_tensor * wo_b,
-         struct ggml_tensor * k_cur,
-         struct ggml_tensor * v_cur,
-         struct ggml_tensor * q_cur,
-         struct ggml_tensor * kq_mask,
-                    int32_t   n_tokens,
-                    int32_t   kv_head,
-                    int32_t   n_kv,
-                    float     kq_scale,
-         const llm_build_cb & cb,
-                    int       il) {
-    const llama_hparams & hparams = lctx.model.hparams;
-    const llama_cparams & cparams = lctx.cparams;
-
-    // these nodes are added to the graph together so that they are not reordered
-    // by doing so, the number of splits in the graph is reduced
-    ggml_build_forward_expand(graph, q_cur);
-    ggml_build_forward_expand(graph, k_cur);
-    ggml_build_forward_expand(graph, v_cur);
-
-    llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
-
-    struct ggml_tensor * cur;
-
-    cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
-    cb(cur, "kqv_out", il);
-
-    return cur;
-}
-
-static struct ggml_tensor * llm_build_copy_mask_state(
-        struct ggml_context * ctx,
-         struct ggml_cgraph * graph,
-         struct ggml_tensor * s,
-         struct ggml_tensor * state_copy,
-         struct ggml_tensor * state_mask,
-                    int32_t   n_state,
-                    int32_t   kv_size,
-                    int32_t   kv_head,
-                    int32_t   n_kv,
-                    int32_t   n_seqs) {
-    struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size);
-
-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx, states, state_copy);
-
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx, states, state_mask);
-
-    // copy states which won't be changed further (between n_seqs and n_kv)
-    ggml_build_forward_expand(graph,
-        ggml_cpy(ctx,
-            ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
-            ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
-
-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
-}
-
-// TODO: split
-static struct ggml_tensor * llm_build_mamba(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-         const llama_ubatch & ubatch,
-         struct ggml_cgraph * graph,
-         struct ggml_tensor * cur,
-         struct ggml_tensor * state_copy,
-         struct ggml_tensor * state_mask,
-                    int32_t   kv_head,
-                    int32_t   n_kv,
-         const llm_build_cb & cb,
-                    int       il) {
-    const llama_model    & model   = lctx.model;
-    const llama_hparams  & hparams = model.hparams;
-    const llama_kv_cache & kv      = lctx.kv_self;
-    const int64_t d_conv  = hparams.ssm_d_conv;
-    const int64_t d_inner = hparams.ssm_d_inner;
-    const int64_t d_state = hparams.ssm_d_state;
-    const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs  = ubatch.n_seqs;
-    // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
-    const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
-    // Use the same RMS norm as the final layer norm
-    const float norm_rms_eps = hparams.f_norm_rms_eps;
-
-    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
-    GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(ubatch.equal_seqs);
-    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
-    struct ggml_tensor * conv_states_all = kv.k_l[il];
-    struct ggml_tensor * ssm_states_all  = kv.v_l[il];
-
-    // (ab)using the KV cache to store the states
-    struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
-            graph, conv_states_all, state_copy, state_mask,
-            hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
-    conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs);
-    struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
-            graph, ssm_states_all, state_copy, state_mask,
-            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
-    ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs);
-
-    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
-    cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
-    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
-    struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur);
-    // split the above in two
-    // => {d_inner, n_seq_tokens, n_seqs}
-    struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
-    struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz));
-
-    // conv
-    {
-        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
-        struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, x), 0);
-
-        // copy last (d_conv - 1) columns back into the state cache
-        struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
-
-        ggml_build_forward_expand(graph,
-            ggml_cpy(ctx, last_conv,
-                ggml_view_1d(ctx, conv_states_all,
-                    (d_conv - 1)*(d_inner)*(n_seqs),
-                    kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
-
-        // 1D convolution
-        // The equivalent is to make a self-overlapping view of conv_x
-        // over d_conv columns at each stride in the 3rd dimension,
-        // then element-wise multiply that with the conv1d weight,
-        // then sum the elements of each row,
-        // (the last two steps are a dot product over rows (also doable with mul_mat))
-        // then permute away the ne[0] dimension,
-        // and then you're left with the resulting x tensor.
-        // For simultaneous sequences, all sequences need to have the same length.
-        x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
-
-        // bias
-        x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
-
-        x = ggml_silu(ctx, x);
-    }
-
-    // ssm
-    {
-        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
-        struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x);
-        // split
-        struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
-        struct ggml_tensor * B  = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
-        struct ggml_tensor * C  = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
-
-        // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
-        if (ssm_dt_b_c_rms) {
-            dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
-            B = ggml_rms_norm(ctx, B, norm_rms_eps);
-            C = ggml_rms_norm(ctx, C, norm_rms_eps);
-        }
-
-        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
-        dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt);
-        dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
-
-        // Custom operator to optimize the parallel associative scan
-        // as described in the Annex D of the Mamba paper.
-        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-        struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
-
-        // store last states
-        ggml_build_forward_expand(graph,
-            ggml_cpy(ctx,
-                ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
-                ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
-
-        struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
-
-        // TODO: skip computing output earlier for unused tokens
-
-        // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
-        y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d));
-        y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z)));
-
-        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-        cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y);
-    }
-
-    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
-    cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs);
-    cb(cur, "mamba_out", il);
-
-    return cur;
-}
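
For reference, the comment block in the conv section above describes the 1D convolution as a self-overlapping view over d_conv columns, an element-wise multiply with the conv1d weight, and a row-sum. A scalar sketch of that causal depthwise convolution for a single inner channel (illustrative only, not part of the patch):

#include <vector>

// causal depthwise 1D conv for one inner channel:
// state holds the previous (d_conv - 1) inputs, w holds the d_conv taps
static std::vector<float> ssm_conv_channel(const std::vector<float> & state,  // d_conv - 1 values
                                           const std::vector<float> & x,      // n_seq_tokens values
                                           const std::vector<float> & w) {    // d_conv taps
    const size_t d_conv = w.size();

    // conv_x = concat(state, x): the source of the "self-overlapping view"
    std::vector<float> conv_x(state);
    conv_x.insert(conv_x.end(), x.begin(), x.end());

    std::vector<float> y(x.size(), 0.0f);
    for (size_t t = 0; t < x.size(); ++t) {
        // dot product of d_conv consecutive inputs with the conv1d weight row
        for (size_t k = 0; k < d_conv; ++k) {
            y[t] += conv_x[t + k] * w[k];
        }
    }
    return y;
}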
-
-static struct ggml_tensor * llm_build_rwkv6_time_mix(
-        struct llama_context & lctx,
-        struct ggml_context * ctx,
-        const struct llama_layer * layer,
-        struct ggml_tensor * cur,
-        struct ggml_tensor * x_prev,
-        struct ggml_tensor ** wkv_state,
-        size_t wkv_head_size,
-        size_t head_count_kv) {
-    size_t n_embd       = cur->ne[0];
-    size_t n_seq_tokens = cur->ne[1];
-    size_t n_seqs       = cur->ne[2];
-
-    size_t head_size  = wkv_head_size;
-    size_t head_count = n_embd / head_size;
-
-    size_t n_tokens = n_seqs * n_seq_tokens;
-
-    bool is_qrwkv = layer->time_mix_first == nullptr;
-
-    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
-
-    sx  = ggml_reshape_2d(ctx, sx,  n_embd, n_tokens);
-    cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
-
-    struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
-
-    xxx = ggml_reshape_4d(
-        ctx,
-        ggml_tanh(
-            ctx,
-            ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
-        ),
-        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
-    );
-
-    xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2));
-
-    xxx = ggml_mul_mat(
-        ctx,
-        ggml_reshape_4d(
-            ctx,
-            layer->time_mix_w2,
-            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
-        ),
-        xxx
-    );
-
-    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
-    if (layer->time_mix_lerp_fused) {
-        // fusing these weights makes some performance improvement
-        sx  = ggml_reshape_3d(ctx, sx,  n_embd, 1, n_tokens);
-        cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
-        xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
-        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
-    } else {
-        // for backward compatibility
-        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
-
-        xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
-        xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
-        xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
-        xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
-        xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
-    }
-
-    struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
-    struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk);
-    struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv);
-    if (layer->time_mix_receptance_b) {
-        r = ggml_add(ctx, r, layer->time_mix_receptance_b);
-    }
-    if (layer->time_mix_key_b) {
-        k = ggml_add(ctx, k, layer->time_mix_key_b);
-    }
-    if (layer->time_mix_value_b) {
-        v = ggml_add(ctx, v, layer->time_mix_value_b);
-    }
-
-    struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg);
-    if (is_qrwkv) {
-        g = ggml_sigmoid(ctx, g);
-    } else {
-        g = ggml_silu(ctx, g);
-    }
-
-    if (head_count_kv != head_count) {
-        GGML_ASSERT(head_count % head_count_kv == 0);
-        k = ggml_reshape_4d(ctx, k, head_size, 1, head_count_kv, n_tokens);
-        v = ggml_reshape_4d(ctx, v, head_size, 1, head_count_kv, n_tokens);
-        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
-        k = ggml_repeat(ctx, k, tmp);
-        v = ggml_repeat(ctx, v, tmp);
-    }
-
-    k = ggml_reshape_3d(ctx, k, head_size, head_count, n_tokens);
-    v = ggml_reshape_3d(ctx, v, head_size, head_count, n_tokens);
-    r = ggml_reshape_3d(ctx, r, head_size, head_count, n_tokens);
-
-    struct ggml_tensor * w = ggml_mul_mat(
-        ctx,
-        layer->time_mix_decay_w2,
-        ggml_tanh(
-            ctx,
-            ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
-        )
-    );
-
-    w = ggml_add(ctx, w, layer->time_mix_decay);
-    w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
-    w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens);
-
-    if (is_qrwkv) {
-        // k = k * (1 - w)
-        k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
-    }
-
-    struct ggml_tensor * wkv_output;
-    if (!layer->time_mix_first) {
-        wkv_output = ggml_gated_linear_attn(ctx, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
-    } else {
-        wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
-    }
-    cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
-    *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
-
-    if (!is_qrwkv) {
-        // group norm with head_count groups
-        cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
-        cur = ggml_norm(ctx, cur, 64e-5f);
-
-        // Convert back to regular vectors.
-        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
-        cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
-    } else {
-        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
-    }
-
-    cur = ggml_mul(ctx, cur, g);
-    cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
-
-    return ggml_reshape_3d(ctx, cur, n_embd, n_seq_tokens, n_seqs);
-}
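
For reference, the removed time-mix maps the learned decay through w = exp(-exp(w)), which confines the per-channel decay to (0, 1) regardless of the logit's value. A one-function sketch (illustrative only, not part of the patch):

#include <cmath>

// maps an unconstrained decay logit into (0, 1): large negative logits give
// decay near 1 (long memory), large positive logits give decay near 0
static float rwkv_decay(float w_logit) {
    return std::exp(-std::exp(w_logit));
}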
-
-static struct ggml_tensor * llm_build_rwkv6_channel_mix(
-        struct llama_context & lctx,
-        struct ggml_context * ctx,
-        const struct llama_layer * layer,
-        struct ggml_tensor * cur,
-        struct ggml_tensor * x_prev) {
-    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
-    struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
-    struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
-
-    struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr));
-    struct ggml_tensor * k = ggml_sqr(
-        ctx,
-        ggml_relu(
-            ctx,
-            llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
-        )
-    );
-
-    return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
-}
-
-// block of KV slots to move when defragging
-struct llama_kv_defrag_move {
-    uint32_t src;
-    uint32_t dst;
-    uint32_t len;
-};
-
-struct llm_build_context {
-    const llama_model    & model;
-          llama_context  & lctx;
-    const llama_hparams  & hparams;
-    const llama_cparams  & cparams;
-    const llama_ubatch   & ubatch;
-    const llama_kv_cache & kv_self;
-
-    const int64_t n_embd;
-    const int64_t n_layer;
-    const int64_t n_rot;
-    const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
-    const int64_t n_head;
-    const int64_t n_head_kv;
-    const int64_t n_embd_head_k;
-    const int64_t n_embd_k_gqa;
-    const int64_t n_embd_head_v;
-    const int64_t n_embd_v_gqa;
-    const int64_t n_expert;
-    const int64_t n_expert_used;
-
-    const float freq_base;
-    const float freq_scale;
-    const float ext_factor;
-    const float attn_factor;
-    const float beta_fast;
-    const float beta_slow;
-    const float norm_eps;
-    const float norm_rms_eps;
-
-    const int32_t n_tokens;
-    const int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
-    const int32_t n_outputs;
-    const int32_t n_outputs_enc;
-    const int32_t kv_head;  // index of where we store new KV data in the cache
-    const int32_t n_ctx_orig;
-
-    const bool flash_attn;
-
-    const enum llama_pooling_type pooling_type;
-    const enum llama_rope_type    rope_type;
-
-    const llm_build_cb & cb;
-
-    std::vector & buf_compute_meta;
-
-    struct ggml_context * ctx0 = nullptr;
-
-    // TODO: consider making the entire interface noexcept
-    llm_build_context(
-        llama_context  & lctx,
-    const llama_ubatch & ubatch,
-    const llm_build_cb & cb,
-                  bool   worst_case) :
-        model            (lctx.model),
-        lctx             (lctx),
-        hparams          (model.hparams),
-        cparams          (lctx.cparams),
-        ubatch           (ubatch),
-        kv_self          (lctx.kv_self),
-        n_embd           (hparams.n_embd),
-        n_layer          (hparams.n_layer),
-        n_rot            (hparams.n_rot),
-        n_ctx            (cparams.n_ctx),
-        n_head           (hparams.n_head()),
-        n_head_kv        (hparams.n_head_kv()),
-        n_embd_head_k    (hparams.n_embd_head_k),
-        n_embd_k_gqa     (hparams.n_embd_k_gqa()),
-        n_embd_head_v    (hparams.n_embd_head_v),
-        n_embd_v_gqa     (hparams.n_embd_v_gqa()),
-        n_expert         (hparams.n_expert),
-        n_expert_used    (hparams.n_expert_used),
-        freq_base        (cparams.rope_freq_base),
-        freq_scale       (cparams.rope_freq_scale),
-        ext_factor       (cparams.yarn_ext_factor),
-        attn_factor      (cparams.yarn_attn_factor),
-        beta_fast        (cparams.yarn_beta_fast),
-        beta_slow        (cparams.yarn_beta_slow),
-        norm_eps         (hparams.f_norm_eps),
-        norm_rms_eps     (hparams.f_norm_rms_eps),
-        n_tokens         (ubatch.n_tokens),
-        n_kv             (worst_case ? kv_self.size : kv_self.n),
-        n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
-        n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
-        kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
-        n_ctx_orig       (cparams.n_ctx_orig_yarn),
-        flash_attn       (cparams.flash_attn),
-        pooling_type     (cparams.pooling_type),
-        rope_type        (hparams.rope_type),
-        cb               (cb),
-        buf_compute_meta (lctx.buf_compute_meta) {
-            // all initializations should be done in init()
-        }
-
-    void init() {
-        struct ggml_init_params params = {
-            /*.mem_size   =*/ buf_compute_meta.size(),
-            /*.mem_buffer =*/ buf_compute_meta.data(),
-            /*.no_alloc   =*/ true,
-        };
-
-        ctx0 = ggml_init(params);
-
-        lctx.inp_tokens      = nullptr;
-        lctx.inp_embd        = nullptr;
-        lctx.inp_pos         = nullptr;
-        lctx.inp_out_ids     = nullptr;
-        lctx.inp_KQ_mask     = nullptr;
-        lctx.inp_KQ_mask_swa = nullptr;
-        lctx.inp_K_shift     = nullptr;
-        lctx.inp_mean        = nullptr;
-        lctx.inp_cls         = nullptr;
-        lctx.inp_s_copy      = nullptr;
-        lctx.inp_s_mask      = nullptr;
-        lctx.inp_s_seq       = nullptr;
-        lctx.inp_pos_bucket    = nullptr;
-        lctx.inp_embd_enc      = nullptr;
-        lctx.inp_KQ_mask_cross = nullptr;
-        lctx.inp_cross_attn_state = nullptr;
-    }
-
-    void free() {
-        ggml_free(ctx0);
-        ctx0 = nullptr;
-    }
-
-    struct ggml_cgraph * build_k_shift() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        GGML_ASSERT(kv_self.size == n_ctx);
-
-        lctx.inp_K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        cb(lctx.inp_K_shift, "K_shift", -1);
-        ggml_set_input(lctx.inp_K_shift);
-
-        for (int il = 0; il < n_layer; ++il) {
-            const int64_t n_head_kv = hparams.n_head_kv(il);
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-            struct ggml_tensor * rope_factors = build_rope_factors(il);
-            struct ggml_tensor * k =
-                ggml_view_3d(ctx0, kv_self.k_l[il],
-                    n_embd_head_k, n_head_kv, n_ctx,
-                    ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                    0);
-
-            struct ggml_tensor * tmp;
-            if (ggml_is_quantized(k->type)) {
-                // dequantize to f32 -> RoPE -> quantize back
-                tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
-                cb(tmp, "K_f32", il);
-                for (auto & backend : lctx.backends) {
-                    // Figure out which backend KV cache belongs to
-                    if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) {
-                        ggml_backend_sched_set_tensor_backend(lctx.sched.get(), tmp, backend.get());
-                        break;
-                    }
-                }
-                tmp = ggml_rope_ext_inplace(ctx0, tmp,
-                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(tmp, "K_shifted_f32", il);
-                tmp = ggml_cpy(ctx0, tmp, k);
-            } else {
-                // we rotate only the first n_rot dimensions
-                tmp = ggml_rope_ext_inplace(ctx0, k,
-                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-            }
-            cb(tmp, "K_shifted", il);
-            ggml_build_forward_expand(gf, tmp);
-        }
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_defrag(const std::vector & moves) {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        for (const auto & move : moves) {
-            for (int il = 0; il < n_layer; ++il) {
-                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-                ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, move.len,
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
-
-                ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, move.len,
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
-
-                ggml_tensor * view_v_src;
-                ggml_tensor * view_v_dst;
-
-                if (flash_attn) {
-                    // NOTE: the V cache is not transposed when using flash attention
-                    view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, move.len,
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
-
-                    view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, move.len,
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
-                } else {
-                    view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            move.len, n_embd_v_gqa,
-                            ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, move.src));
-
-                    view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            move.len, n_embd_v_gqa,
-                            ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, move.dst));
-                }
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
-            }
-        }
-
-        //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-
-        return gf;
-    }
-
-    struct ggml_tensor * build_inp_pos() {
-        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(lctx.inp_pos, "inp_pos", -1);
-        ggml_set_input(lctx.inp_pos);
-        return lctx.inp_pos;
-    }
-
-    struct ggml_tensor * build_rope_factors(int il) {
-        // choose long/short freq factors based on the context size
-        const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
-
-        if (model.layers[il].rope_freqs != nullptr) {
-            return model.layers[il].rope_freqs;
-        }
-
-        if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
-            return model.layers[il].rope_long;
-        }
-
-        return model.layers[il].rope_short;
-    }
-
-    struct ggml_tensor * build_inp_out_ids() {
-        lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
-        cb(lctx.inp_out_ids, "inp_out_ids", -1);
-        ggml_set_input(lctx.inp_out_ids);
-        return lctx.inp_out_ids;
-    }
-
-    struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        lctx.inp_KQ_mask = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
-        ggml_set_input(lctx.inp_KQ_mask);
-
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
-    }
-
-    struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
-        GGML_ASSERT(hparams.n_swa > 0);
-
-        lctx.inp_KQ_mask_swa = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(lctx.inp_KQ_mask_swa);
-
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
-    }
-
-    struct ggml_tensor * build_inp_mean() {
-        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
-        cb(lctx.inp_mean, "inp_mean", -1);
-        ggml_set_input(lctx.inp_mean);
-        return lctx.inp_mean;
-    }
-
-    struct ggml_tensor * build_inp_cls() {
-        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(lctx.inp_cls, "inp_cls", -1);
-        ggml_set_input(lctx.inp_cls);
-        return lctx.inp_cls;
-    }
-
-    struct ggml_tensor * build_inp_s_copy() {
-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-        cb(lctx.inp_s_copy, "inp_s_copy", -1);
-        ggml_set_input(lctx.inp_s_copy);
-        return lctx.inp_s_copy;
-    }
-
-    struct ggml_tensor * build_inp_s_mask() {
-        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-        cb(lctx.inp_s_mask, "inp_s_mask", -1);
-        ggml_set_input(lctx.inp_s_mask);
-        return lctx.inp_s_mask;
-    }
-
-    struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
-        // find result_norm tensor for input
-        struct ggml_tensor * inp = nullptr;
-        for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
-            inp = ggml_graph_node(gf, i);
-            if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
-                break;
-            } else {
-                inp = nullptr;
-            }
-        }
-        GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
-
-        struct ggml_tensor * cur;
-
-        switch (pooling_type) {
-            case LLAMA_POOLING_TYPE_NONE:
-                {
-                    cur = inp;
-                } break;
-            case LLAMA_POOLING_TYPE_MEAN:
-                {
-                    struct ggml_tensor * inp_mean = build_inp_mean();
-                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
-                } break;
-            case LLAMA_POOLING_TYPE_CLS:
-            case LLAMA_POOLING_TYPE_LAST:
-                {
-                    struct ggml_tensor * inp_cls = build_inp_cls();
-                    cur = ggml_get_rows(ctx0, inp, inp_cls);
-                } break;
-            case LLAMA_POOLING_TYPE_RANK:
-                {
-                    struct ggml_tensor * inp_cls = build_inp_cls();
-                    inp = ggml_get_rows(ctx0, inp, inp_cls);
-
-                    // classification head
-                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                    GGML_ASSERT(model.cls       != nullptr);
-                    GGML_ASSERT(model.cls_b     != nullptr);
-
-                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
-                    cur = ggml_tanh(ctx0, cur);
-
-                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
-                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
-                    if (model.cls_out) {
-                        GGML_ASSERT(model.cls_out_b != nullptr);
-
-                        cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
-                    }
-                } break;
-            default:
-                {
-                    GGML_ABORT("unknown pooling type");
-                }
-        }
-
-        cb(cur, "result_embd_pooled", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_tensor * llm_build_pos_bucket(bool causal) {
-        if (causal) {
-            lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv,     n_tokens);
-        } else {
-            lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
-        }
-
-        ggml_set_input(lctx.inp_pos_bucket);
-        cb(lctx.inp_pos_bucket, "pos_bucket", -1);
-
-        return lctx.inp_pos_bucket;
-    }
-
-    struct ggml_tensor * llm_build_pos_bias(struct ggml_tensor * pos_bucket, struct ggml_tensor * attn_rel_b) {
-        struct ggml_tensor * pos_bucket_1d = ggml_view_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0);
-        cb(pos_bucket_1d, "pos_bucket_1d", -1);
-
-        struct ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
-        cb(pos_bias, "pos_bias", -1);
-
-        pos_bias = ggml_view_3d(ctx0, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1], ggml_element_size(pos_bias) * pos_bias->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0],  0);
-        cb(pos_bias, "pos_bias", -1);
-
-        pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
-        cb(pos_bias, "pos_bias", -1);
-
-        pos_bias = ggml_cont(ctx0, pos_bias);
-        cb(pos_bias, "pos_bias", -1);
-
-        return pos_bias;
-    }
-
-    struct ggml_tensor * llm_build_inp_embd_enc() {
-        const int64_t n_embd = hparams.n_embd;
-        lctx.inp_embd_enc = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
-        ggml_set_input(lctx.inp_embd_enc);
-        cb(lctx.inp_embd_enc, "embd_enc", -1);
-        return lctx.inp_embd_enc;
-    }
-
-    struct ggml_tensor * llm_build_inp_KQ_mask_cross() {
-        lctx.inp_KQ_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        ggml_set_input(lctx.inp_KQ_mask_cross);
-        cb(lctx.inp_KQ_mask_cross, "KQ_mask_cross", -1);
-        return lctx.inp_KQ_mask_cross;
-    }
-
-    struct ggml_cgraph * build_llama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            if (model.layers[il].ffn_gate_inp == nullptr) {
-
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            } else {
-                // MoE branch
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_gate_inp,
-                        model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps,
-                        model.layers[il].ffn_down_exps,
-                        nullptr,
-                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, true,
-                        false, 0.0,
-                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                        cb, il);
-                cb(cur, "ffn_moe_out", il);
-            }
-
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_mllama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-        struct ggml_tensor * inpCAS;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-        inpCAS = llm_build_inp_cross_attn_state(ctx0, lctx, hparams, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            if (hparams.cross_attention_layers(il)) {
-                if (!ubatch.embd && !cparams.cross_attn) {
-                    continue;
-                }
-
-                // cross attention layer
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
-                cb(Qcur, "Qcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                cb(Qcur, "Qcur", il);
-
-                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
-                cb(Qcur, "Qcur", il);
-
-                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur, * Vcur;
-                if (ubatch.embd) {
-                    Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
-                    cb(Kcur, "Kcur", il);
-
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
-                    cb(Kcur, "Kcur", il);
-
-                    Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
-                    cb(Kcur, "Kcur", il);
-
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
-                    cb(Kcur, "Kcur", il);
-
-                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
-
-                    Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
-                    cb(Vcur, "Vcur", il);
-
-                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404);
-                    cb(Vcur, "Vcur", il);
-
-                    Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
-                    cb(Vcur, "Vcur", il);
-
-                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
-                } else {
-                    Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
-                    cb(Kcur, "Kcur (view)", il);
-
-                    Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
-                    cb(Vcur, "Vcur (view)", il);
-                }
-
-                struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
-                cb(kq, "kq", il);
-
-                // TODO: apply causal masks
-                struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
-                cb(kq_soft_max, "kq_soft_max", il);
-
-                Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
-                cb(Vcur, "Vcur", il);
-
-                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max);
-                cb(kqv, "kqv", il);
-
-                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                cb(kqv_merged, "kqv_merged", il);
-
-                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-                cb(cur, "kqv_merged_cont", il);
-
-                cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur);
-                cb(cur, "cur", il);
-
-                // TODO: do this in place once?
-                cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate));
-
-                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-                cb(ffn_inp, "ffn_inp", il);
-
-                // feed-forward network
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-
-                // TODO: do this inplace once?
-                cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp);
-                cb(cur, "ffn_out", il);
-
-                cur = lctx.cvec.apply_to(ctx0, cur, il);
-                cb(cur, "l_out", il);
-
-                // input for next layer
-                inpL = cur;
-            } else {
-                // self attention layer
-
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                    model.layers[il].wo, model.layers[il].bo,
-                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-
-
-                if (il == n_layer - 1) {
-                    // skip computing output for unused tokens
-                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                    n_tokens = n_outputs;
-                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-                }
-
-                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-                cb(ffn_inp, "ffn_inp", il);
-
-                // feed-forward network
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-
-                cur = ggml_add(ctx0, cur, ffn_inp);
-                cb(cur, "ffn_out", il);
-
-                cur = lctx.cvec.apply_to(ctx0, cur, il);
-                cb(cur, "l_out", il);
-
-                // input for next layer
-                inpL = cur;
-            }
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_deci() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-            const int64_t n_head_kv = hparams.n_head_kv(il);
-            const int64_t n_head    = hparams.n_head(il);
-
-            if (n_head == 0) {
-                // attention-free layer of Llama-3_1-Nemotron-51B
-                cur = inpL;
-            } else {
-                // norm
-                cur = llm_build_norm(ctx0, inpL, hparams,
-                        model.layers[il].attn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_norm", il);
-            }
-
-            if (n_head > 0 && n_head_kv == 0) {
-                // "linear attention" of Llama-3_1-Nemotron-51B
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-                cb(cur, "wo", il);
-            } else if (n_head > 0) {
-                // self-attention
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
-            // modified to support attention-free layer of Llama-3_1-Nemotron-51B
-            struct ggml_tensor * ffn_inp = cur;
-            if (n_head > 0) {
-                ffn_inp = ggml_add(ctx0, cur, inpSA);
-                cb(ffn_inp, "ffn_inp", il);
-            }
-
-            // feed-forward network
-            if (model.layers[il].ffn_gate_inp == nullptr) {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_baichuan() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                switch (model.type) {
-                    case LLM_TYPE_7B:
-                        Qcur = ggml_rope_ext(
-                            ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                            ext_factor, attn_factor, beta_fast, beta_slow
-                        );
-                        Kcur = ggml_rope_ext(
-                            ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                            ext_factor, attn_factor, beta_fast, beta_slow
-                        );
-                        break;
-                    case LLM_TYPE_13B:
-                        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
-                        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
-                        break;
-                    default:
-                        GGML_ABORT("fatal error");
-                }
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_xverse() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,      cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_falcon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * attn_norm;
-
-            attn_norm = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(attn_norm, "attn_norm", il);
-
-            // self-attention
-            {
-                if (model.layers[il].attn_norm_2) {
-                    // Falcon-40B
-                    cur = llm_build_norm(ctx0, inpL, hparams,
-                            model.layers[il].attn_norm_2,
-                            model.layers[il].attn_norm_2_b,
-                            LLM_NORM, cb, il);
-                    cb(cur, "attn_norm_2", il);
-                } else {
-                    cur = attn_norm;
-                }
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                // using mode = 2 for neox mode
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur       = ggml_get_rows(ctx0,       cur, inp_out_ids);
-                inpL      = ggml_get_rows(ctx0,      inpL, inp_out_ids);
-                attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = cur;
-
-            // feed forward
-            {
-                cur = llm_build_ffn(ctx0, lctx, attn_norm, // !! use the attn norm, not the result
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        NULL,                      NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        // norm
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_grok() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // multiply by embedding_multiplier_scale of 78.38367176906169
-        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // Grok
-            // if attn_out_norm is present then apply it before adding the input
-            if (model.layers[il].attn_out_norm) {
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].attn_out_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_out_norm", il);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            // MoE branch
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_GELU, true,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    cb, il);
-            cb(cur, "ffn_moe_out", il);
-
-            // Grok
-            // if layer_out_norm is present then apply it before adding the input
-            // Idea: maybe ffn_out_norm is a better name
-            if (model.layers[il].layer_out_norm) {
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].layer_out_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "layer_out_norm", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        // Grok
-        // multiply logits by output_multiplier_scale of 0.5773502691896257
-
-        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_dbrx() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                                 model.layers[il].attn_norm, NULL,
-                                 LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = nullptr;
-                struct ggml_tensor * Kcur = nullptr;
-                struct ggml_tensor * Vcur = nullptr;
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                cb(cur, "wqkv_clamped", il);
-
-                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            // MoE branch
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                                 model.layers[il].attn_out_norm, NULL,
-                                 LLM_NORM, cb, il);
-            cb(cur, "attn_out_norm", il);
-
-            cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, true,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    cb, il);
-            cb(cur, "ffn_moe_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                             model.output_norm, NULL,
-                             LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_starcoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        struct ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
-        cb(pos, "pos_embd", -1);
-
-        inpL = ggml_add(ctx0, inpL, pos);
-        cb(inpL, "inpL", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_refact() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                cb(Kcur, "Kcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                cb(Qcur, "Qcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_bert() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-        struct ggml_tensor * inp_pos = nullptr;
-
-        if (model.arch != LLM_ARCH_JINA_BERT_V2) {
-            inp_pos = build_inp_pos();
-        }
-
-        // construct input embeddings (token, type, position)
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // token types are hardcoded to zero ("Sentence A")
-        struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
-        inpL = ggml_add(ctx0, inpL, type_row0);
-        if (model.arch == LLM_ARCH_BERT) {
-            inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
-        }
-        cb(inpL, "inp_embd", -1);
-
-        // embed layer norm
-        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
-        cb(inpL, "inp_norm", -1);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask(false);
-
-        // iterate layers
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * cur = inpL;
-
-            struct ggml_tensor * Qcur;
-            struct ggml_tensor * Kcur;
-            struct ggml_tensor * Vcur;
-
-            // self-attention
-            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
-                Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq);
-                cb(Qcur, "Qcur", il);
-
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                            model.layers[il].attn_q_norm,
-                            model.layers[il].attn_q_norm_b,
-                            LLM_NORM, cb, il);
-                }
-
-                Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur), model.layers[il].bk);
-                cb(Kcur, "Kcur", il);
-
-                if (model.layers[il].attn_k_norm) {
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                            model.layers[il].attn_k_norm,
-                            model.layers[il].attn_k_norm_b,
-                            LLM_NORM, cb, il);
-                }
-                Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur), model.layers[il].bv);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-            } else {
-                // compute Q and K and RoPE them
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-            }
-
-            struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-            struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
-
-            struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-            cb(kq, "kq", il);
-
-            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
-            cb(kq, "kq_soft_max_ext", il);
-
-            struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
-            cb(v, "v", il);
-
-            struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
-            cb(kqv, "kqv", il);
-
-            struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-            cb(kqv_merged, "kqv_merged", il);
-
-            cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-            cb(cur, "kqv_merged_cont", il);
-
-            ggml_build_forward_expand(gf, cur);
-
-            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-            if (model.layers[il].bo) {
-                cb(cur, "kqv_wo", il);
-            }
-
-            if (model.layers[il].bo) {
-                cur = ggml_add(ctx0, cur, model.layers[il].bo);
-            }
-            cb(cur, "kqv_out", il);
-
-            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // re-add the layer input
-            cur = ggml_add(ctx0, cur, inpL);
-
-            // attention layer norm
-            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, cb, il);
-
-            if (model.layers[il].attn_norm_2 != nullptr) {
-                cur = ggml_add(ctx0, cur, inpL); // re-add the layer input
-                cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm_2, model.layers[il].attn_norm_2_b, LLM_NORM, cb, il);
-            }
-
-            struct ggml_tensor * ffn_inp = cur;
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            if (model.arch == LLM_ARCH_BERT) {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-            } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL,                        NULL,
-                        model.layers[il].ffn_gate, NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
-            } else {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            }
-            cb(cur, "ffn_out", il);
-
-            // attentions bypass the intermediate layer
-            cur = ggml_add(ctx0, cur, ffn_inp);
-
-            // output layer norm
-            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].layer_out_norm, model.layers[il].layer_out_norm_b, LLM_NORM, cb, il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cb(cur, "result_embd", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_bloom() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        inpL = llm_build_norm(ctx0, inpL, hparams,
-                model.tok_norm,
-                model.tok_norm_b,
-                LLM_NORM, cb, -1);
-        cb(inpL, "inp_norm", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // Add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_mpt() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * pos;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        if (model.pos_embd) {
-            // inp_pos - contains the positions
-            struct ggml_tensor * inp_pos = build_inp_pos();
-            pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
-            cb(pos, "pos_embd", -1);
-
-            inpL = ggml_add(ctx0, inpL, pos);
-            cb(inpL, "inpL", -1);
-        }
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * attn_norm;
-
-            attn_norm = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(attn_norm, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = attn_norm;
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                if (model.layers[il].bqkv){
-                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                    cb(cur, "bqkv", il);
-                }
-
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(cur, "wqkv_clamped", il);
-                }
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                // Q/K Layernorm
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                            model.layers[il].attn_q_norm,
-                            model.layers[il].attn_q_norm_b,
-                            LLM_NORM, cb, il);
-                    cb(Qcur, "Qcur", il);
-
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                            model.layers[il].attn_k_norm,
-                            model.layers[il].attn_k_norm_b,
-                            LLM_NORM, cb, il);
-                    cb(Kcur, "Kcur", il);
-
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                            model.layers[il].wo, model.layers[il].bo,
-                            Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                } else {
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                            model.layers[il].wo, model.layers[il].bo,
-                            Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                }
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // Add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed forward
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        model.layers[il].ffn_act,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_stablelm() {
-        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            struct ggml_tensor * inpSA = cur;
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                cb(Qcur, "Qcur", il);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                cb(Kcur, "Kcur", il);
-
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                            model.layers[il].attn_q_norm,
-                            NULL,
-                            LLM_NORM, cb, il);
-                    cb(Qcur, "Qcur", il);
-                }
-                if (model.layers[il].attn_k_norm) {
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                            model.layers[il].attn_k_norm,
-                            NULL,
-                            LLM_NORM, cb, il);
-                    cb(Kcur, "Kcur", il);
-                }
-
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpL  = ggml_get_rows(ctx0,  inpL, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                if (model.layers[il].ffn_norm) {
-                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                            model.layers[il].ffn_norm,
-                            model.layers[il].ffn_norm_b,
-                            LLM_NORM, cb, il);
-                    cb(cur, "ffn_norm", il);
-                } else {
-                    // parallel residual
-                    cur = inpSA;
-                }
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_qwen() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                // using mode = 2 for neox mode
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_qwen2vl() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4);
-        cb(lctx.inp_pos, "inp_pos", -1);
-        ggml_set_input(lctx.inp_pos);
-        struct ggml_tensor * inp_pos = lctx.inp_pos;
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-        int sections[4];
-        std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_multi(
-                    ctx0,
-                    ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_multi(
-                    ctx0,
-                    ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_qwen2moe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self_attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // MoE branch
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            ggml_tensor * moe_out =
-                    llm_build_moe_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_gate_inp,
-                        model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps,
-                        model.layers[il].ffn_down_exps,
-                        nullptr,
-                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, false,
-                        false, 0.0,
-                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                        cb, il);
-            cb(cur, "ffn_moe_out", il);
-
-            // FFN shared expert
-            {
-                ggml_tensor * cur_gate_inp = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
-                cb(cur_gate_inp, "ffn_shexp_gate_inp", il);
-
-                // sigmoid
-                ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp);
-                cb(cur_gate, "ffn_shexp_gate", il);
-
-                ggml_tensor * cur_ffn = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up_shexp,   NULL, NULL,
-                        model.layers[il].ffn_gate_shexp, NULL, NULL,
-                        model.layers[il].ffn_down_shexp, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur_ffn, "ffn_shexp", il);
-
-                ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate);
-                cb(ffn_shexp_out, "ffn_shexp_out", il);
-
-                moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out);
-                cb(moe_out, "ffn_out", il);
-
-                cur = moe_out;
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_phi2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * attn_norm_output;
-        struct ggml_tensor * ffn_output;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(attn_norm_output, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = nullptr;
-                struct ggml_tensor * Kcur = nullptr;
-                struct ggml_tensor * Vcur = nullptr;
-
-                if (model.layers[il].wqkv) {
-                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
-                    cb(cur, "wqkv", il);
-
-                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                    cb(cur, "bqkv", il);
-
-                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-                } else {
-                    Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
-                    Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
-                    Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
-                }
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                // with phi2, we scale the Q to avoid precision issues
-                // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur              = ggml_get_rows(ctx0,              cur, inp_out_ids);
-                inpL             = ggml_get_rows(ctx0,             inpL, inp_out_ids);
-                attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
-            }
-
-            // FF
-            {
-                ffn_output = llm_build_ffn(ctx0, lctx, attn_norm_output,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(ffn_output, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_output);
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output_no_bias", -1);
-
-        cur = ggml_add(ctx0, cur, model.output_b);
-        cb(cur, "result_output", -1);
-        ggml_build_forward_expand(gf, cur);
-        return gf;
-    }
-
-    struct ggml_cgraph * build_phi3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = nullptr;
-        if (hparams.n_swa == 0) {
-            // Phi-4 doesn't use sliding window attention
-            KQ_mask = build_inp_KQ_mask();
-        } else {
-            KQ_mask = build_inp_KQ_mask_swa();
-        }
-
-        for (int il = 0; il < n_layer; ++il) {
-            auto residual = inpL;
-
-            // self-attention
-            {
-                // rope freq factors for 128k context
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM_RMS, cb, il);
-                cb(attn_norm_output, "attn_norm", il);
-
-                struct ggml_tensor * Qcur = nullptr;
-                struct ggml_tensor * Kcur = nullptr;
-                struct ggml_tensor * Vcur = nullptr;
-
-                if (model.layers[il].wqkv) {
-                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
-                    cb(cur, "wqkv", il);
-
-                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
-                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
-                } else {
-                    Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
-                    Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
-                    Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
-                }
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor* inp_out_ids = build_inp_out_ids();
-                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
-            }
-
-            cur = ggml_add(ctx0, cur, residual);
-            residual = cur;
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
-                LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            // feed-forward network
-            if (model.layers[il].ffn_gate_inp == nullptr) {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        NULL,                      NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            } else {
-                // MoE branch
-                cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_gate_inp,
-                        model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps,
-                        model.layers[il].ffn_down_exps,
-                        nullptr,
-                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, true,
-                        false, 0.0,
-                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                        cb, il);
-                cb(cur, "ffn_moe_out", il);
-            }
-
-            cur = ggml_add(ctx0, residual, cur);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-            model.output_norm,
-            model.output_norm_b,
-            LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        if (model.output_b != nullptr) {
-            cb(cur, "result_output_no_bias", -1);
-            cur = ggml_add(ctx0, cur, model.output_b);
-        }
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-
-    struct ggml_cgraph * build_plamo() {
-        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            struct ggml_tensor * attention_norm = cur;
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-            struct ggml_tensor * sa_out = cur;
-
-            cur = attention_norm;
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur    = ggml_get_rows(ctx0,    cur, inp_out_ids);
-                sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
-                inpL   = ggml_get_rows(ctx0,   inpL, inp_out_ids);
-            }
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, sa_out);
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_gpt2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * pos;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
-        cb(pos, "pos_embd", -1);
-
-        inpL = ggml_add(ctx0, inpL, pos);
-        cb(inpL, "inpL", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_codeshell() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(tmpq, "tmpq", il);
-                cb(tmpk, "tmpk", il);
-                cb(Vcur, "Vcur", il);
-
-                struct ggml_tensor * Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_orion() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                // if (model.layers[il].bq) {
-                //     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                //     cb(Qcur, "Qcur", il);
-                // }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                // if (model.layers[il].bk) {
-                //     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                //     cb(Kcur, "Kcur", il);
-                // }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                // if (model.layers[il].bv) {
-                //     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                //     cb(Vcur, "Vcur", il);
-                // }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_internlm2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_minicpm3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        //TODO: if the model varies, these parameters need to be read from the model
-        const int64_t n_embd_base = 256;
-        const float scale_embd  = 12.0f;
-        const float scale_depth = 1.4f;
-        const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
-
-        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
-        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-        const uint32_t kv_lora_rank = hparams.n_lora_kv;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // scale the input embeddings
-        inpL = ggml_scale(ctx0, inpL, scale_embd);
-        cb(inpL, "inp_scaled", -1);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            struct ggml_tensor * rope_factors = build_rope_factors(il);
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self_attention
-            {
-                struct ggml_tensor * q = NULL;
-                // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
-                q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
-                cb(q, "q", il);
-
-                q = llm_build_norm(ctx0, q, hparams,
-                        model.layers[il].attn_q_a_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(q, "q", il);
-
-                // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
-                q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
-                cb(q, "q", il);
-
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(q->type, hparams.n_embd_head_k),
-                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
-                        0);
-                cb(q_nope, "q_nope", il);
-
-                // and {n_head * n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
-                        ggml_row_size(q->type, hparams.n_embd_head_k),
-                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
-                        ggml_row_size(q->type, n_embd_head_qk_nope));
-                cb(q_pe, "q_pe", il);
-
-                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
-                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
-
-                // split into {kv_lora_rank, n_tokens}
-                struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
-                        kv_pe_compresseed->nb[1],
-                        0);
-                cb(kv_compressed, "kv_compressed", il);
-
-                // and {n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
-                        kv_pe_compresseed->nb[1],
-                        kv_pe_compresseed->nb[1],
-                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
-                cb(k_pe, "k_pe", il);
-
-                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
-                kv_compressed = ggml_cont(ctx0, kv_compressed);
-                kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
-                        model.layers[il].attn_kv_a_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(kv_compressed, "kv_compressed", il);
-
-                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
-                cb(kv, "kv", il);
-
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        0);
-                cb(k_nope, "k_nope", il);
-
-                // and {n_head * n_embd_head_v, n_tokens}
-                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_cont(ctx0, v_states);
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                    0);
-                cb(v_states, "v_states", il);
-
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-                q_pe = ggml_rope_ext(
-                    ctx0, q_pe, inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(q_pe, "q_pe", il);
-
-                // shared RoPE key
-                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-                k_pe = ggml_rope_ext(
-                    ctx0, k_pe, inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(k_pe, "k_pe", il);
-
-                struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
-                cb(q_states, "q_states", il);
-
-                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
-                cb(k_states, "k_states", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // scale_res - scale the hidden states for residual connection
-            const float scale_res = scale_depth/sqrtf(float(n_layer));
-            cur = ggml_scale(ctx0, cur, scale_res);
-            cb(cur, "hidden_scaled", il);
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            // scale the hidden states for residual connection
-            cur = ggml_scale(ctx0, cur, scale_res);
-            cb(cur, "hidden_scaled_ffn", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head scaling
-        const float scale_lmhead = float(n_embd_base)/float(n_embd);
-        cur = ggml_scale(ctx0, cur, scale_lmhead);
-        cb(cur, "lmhead_scaling", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_gemma() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-        cb(inpL, "inp_scaled", -1);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Qcur, "Qcur", il);
-
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
-                cb(Qcur, "Qcur_scaled", il);
-
-                Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
-            cb(sa_out, "sa_out", il);
-
-            cur = llm_build_norm(ctx0, sa_out, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, sa_out);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_gemma2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-        cb(inpL, "inp_scaled", -1);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        // gemma 2 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask(true);
-        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
-
-        for (int il = 0; il < n_layer; ++il) {
-            // (il % 2) layers use SWA
-            struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Qcur, "Qcur", il);
-
-                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
-                switch (model.type) {
-                    case LLM_TYPE_2B:
-                    case LLM_TYPE_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
-                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
-                    default: GGML_ABORT("fatal error");
-                };
-                cb(Qcur, "Qcur_scaled", il);
-
-                Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-            }
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].attn_post_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_post_norm", il);
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
-            cb(sa_out, "sa_out", il);
-
-            cur = llm_build_norm(ctx0, sa_out, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_post_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-            cb(cur, "ffn_post_norm", -1);
-
-            cur = ggml_add(ctx0, cur, sa_out);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        // final logit soft-capping
-        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
-        cur = ggml_tanh(ctx0, cur);
-        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
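The final logit soft-capping at the end of build_gemma2 above is simply a smooth clamp of every logit into (-cap, +cap), where cap is hparams.f_final_logit_softcapping. A minimal scalar sketch of the same arithmetic, independent of ggml (the graph applies it element-wise via ggml_scale, ggml_tanh, ggml_scale; the helper name softcap_logit is illustrative, not from the source):

#include <cmath>

// softcap(x) = cap * tanh(x / cap): divide by the cap, squash with tanh,
// then scale back up, so every logit lands strictly inside (-cap, +cap).
static float softcap_logit(float logit, float cap) {
    return cap * std::tanh(logit / cap);
}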
-
-
-    struct ggml_cgraph * build_starcoder2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_mamba() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        struct ggml_tensor * state_copy = build_inp_s_copy();
-        struct ggml_tensor * state_mask = build_inp_s_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            cur = llm_build_mamba(ctx0, lctx, ubatch, gf, cur,
-                    state_copy, state_mask,
-                    kv_head, n_kv, cb, il);
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // residual
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        // final rmsnorm
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_command_r() {
-
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        const float f_logit_scale = hparams.f_logit_scale;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-            struct ggml_tensor * ffn_inp = cur;
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
-                                ggml_element_size(Qcur) * n_embd_head,
-                                ggml_element_size(Qcur) * n_embd_head * n_head,
-                                0);
-                    cb(Qcur, "Qcur", il);
-                    Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
-                                ggml_element_size(Kcur) * n_embd_head,
-                                ggml_element_size(Kcur) * n_embd_head * n_head_kv,
-                                0);
-                    cb(Kcur, "Kcur", il);
-
-                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                                model.layers[il].attn_q_norm,
-                                NULL,
-                                LLM_NORM, cb, il);
-                    cb(Qcur, "Qcur", il);
-
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                            model.layers[il].attn_k_norm,
-                            NULL,
-                            LLM_NORM, cb, il);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur     = ggml_get_rows(ctx0,     cur, inp_out_ids);
-                inpL    = ggml_get_rows(ctx0,    inpL, inp_out_ids);
-                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
-            }
-
-            struct ggml_tensor * attn_out = cur;
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, ffn_inp,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            // add together residual + FFN + self-attention
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = ggml_add(ctx0, cur, attn_out);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        if (f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, f_logit_scale);
-        }
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-
-    }
-
-    struct ggml_cgraph * build_cohere2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        const float f_logit_scale = hparams.f_logit_scale;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        // cohere2 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask();
-        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
-
-        // sliding window switch pattern
-        const int32_t sliding_window_pattern = 4;
-
-        for (int il = 0; il < n_layer; ++il) {
-            // the first three layers of each group of four use sliding window attention (window size 4096) with RoPE
-            // every fourth layer uses global attention without positional embeddings
-            const bool           is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-            struct ggml_tensor * ffn_inp = cur;
-
-            // self-attention
-            {
-                // rope freq factors for 128k context
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                if (is_sliding) {
-                    Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
-                                        beta_fast, beta_slow);
-                    cb(Qcur, "Qcur", il);
-
-                    Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                                        rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
-                                        attn_factor, beta_fast, beta_slow);
-                    cb(Kcur, "Kcur", il);
-                } else {
-                    // For non-sliding layers, just reshape without applying RoPE
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                    cb(Qcur, "Qcur", il);
-
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
-                                   KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur                              = ggml_get_rows(ctx0, cur, inp_out_ids);
-                inpL                             = ggml_get_rows(ctx0, inpL, inp_out_ids);
-                ffn_inp                          = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
-            }
-
-            struct ggml_tensor * attn_out = cur;
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
-                                    NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
-                                    cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            // add together residual + FFN + self-attention
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = ggml_add(ctx0, cur, attn_out);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        if (f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, f_logit_scale);
-        }
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
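The sliding-window schedule hard-coded in build_cohere2 above follows directly from the modulo test: with a pattern length of 4, layers 0-2 of each group use sliding-window attention with RoPE, and every fourth layer is global with no positional embeddings. A small self-contained sketch of that schedule in plain C++ (no ggml; the printing loop is purely illustrative):

#include <cstdio>

int main() {
    const int sliding_window_pattern = 4;
    // print the attention type of the first two groups of layers
    for (int il = 0; il < 8; ++il) {
        const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
        std::printf("layer %d: %s\n", il, is_sliding ? "SWA + RoPE" : "global, no RoPE");
    }
    return 0;
}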
-
-    // ref: https://allenai.org/olmo
-    // based on the original build_llama() function, changes:
-    //   * non-parametric layer norm
-    //   * clamp qkv
-    //   * removed bias
-    //   * removed MoE
-    struct ggml_cgraph * build_olmo() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    NULL, NULL,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, nullptr,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    NULL, NULL,
-                    LLM_NORM, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                NULL, NULL,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
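The clamp applied to the Q/K/V projections in build_olmo above is a plain symmetric clip controlled by hparams.f_clamp_kqv. A minimal scalar sketch of the same operation (the graph version runs element-wise through ggml_clamp; the helper name clamp_kqv is illustrative):

#include <algorithm>

// clip x to [-f_clamp_kqv, +f_clamp_kqv]; a non-positive setting disables the clamp
static float clamp_kqv(float x, float f_clamp_kqv) {
    return f_clamp_kqv > 0.0f ? std::clamp(x, -f_clamp_kqv, f_clamp_kqv) : x;
}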
-
-    struct ggml_cgraph * build_olmo2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = inpL;
-
-            // self_attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Qcur, "Qcur_normed", il);
-
-                Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Kcur, "Kcur_normed", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur_rope", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur_rope", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].attn_post_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_post_norm", il);
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_ffn(ctx0, lctx, ffn_inp,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_post_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-            cb(cur, "ffn_post_norm", -1);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    // based on the build_qwen2moe() function, changes:
-    //   * removed shared experts
-    //   * removed bias
-    //   * added q, k norm
-    struct ggml_cgraph * build_olmoe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self_attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Qcur, "Qcur_normed", il);
-
-                Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Kcur, "Kcur_normed", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur_rope", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur_rope", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // MoE branch
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, false,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    cb, il);
-            cb(cur, "ffn_moe_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_openelm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            const int64_t n_head    = hparams.n_head(il);
-            const int64_t n_head_kv = hparams.n_head_kv(il);
-            const int64_t n_head_qkv = 2*n_head_kv + n_head;
-
-            cur = inpL;
-            struct ggml_tensor * residual = cur;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0));
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head));
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
-                cb(Vcur, "Vcur", il);
-
-                Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                        model.layers[il].attn_q_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Qcur, "Qcur", il);
-
-                Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                        model.layers[il].attn_k_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Kcur, "Kcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
-                cb(Qcur, "Vcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
-                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        // norm
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
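build_openelm above slices Q, K and V out of a single fused projection: after reshaping to (n_embd_head, n_head + 2*n_head_kv, n_tokens), the first n_head heads are queries, the next n_head_kv are keys, and the last n_head_kv are values. A small sketch of the per-token head offsets under that assumed layout (element offsets, not the ggml byte strides used by the views; the names qkv_offsets and split_fused_qkv are illustrative):

#include <cstddef>

struct qkv_offsets {
    size_t q; // first query element
    size_t k; // first key element
    size_t v; // first value element
};

static qkv_offsets split_fused_qkv(size_t n_head, size_t n_head_kv, size_t n_embd_head) {
    return {
        0,                                  // queries start the fused vector
        n_embd_head * n_head,               // keys follow the n_head query heads
        n_embd_head * (n_head + n_head_kv), // values follow the n_head_kv key heads
    };
}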
-
-    struct ggml_cgraph * build_gptneox() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // ffn
-            if (hparams.use_par_res) {
-                // attention and ffn are computed in parallel
-                // x = x + attn(ln1(x)) + ffn(ln2(x))
-
-                struct ggml_tensor * attn_out = cur;
-
-                cur = llm_build_norm(ctx0, inpL, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-
-                cur = ggml_add(ctx0, cur, inpL);
-                cb(cur, "ffn_out", il);
-
-                cur = ggml_add(ctx0, cur, attn_out);
-                cur = lctx.cvec.apply_to(ctx0, cur, il);
-                cb(cur, "l_out", il);
-
-                // input for next layer
-                inpL = cur;
-            } else {
-                // attention and ffn are computed sequentially
-                // x = x + attn(ln1(x))
-                // x = x + ffn(ln2(x))
-
-                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-                cb(ffn_inp, "ffn_inp", il);
-
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-
-                cur = ggml_add(ctx0, cur, ffn_inp);
-                cur = lctx.cvec.apply_to(ctx0, cur, il);
-                cb(cur, "l_out", il);
-
-                // input for next layer
-                inpL = cur;
-            }
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_arctic() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            struct ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp);
-            cb(ffn_out, "ffn_out", il);
-
-            // MoE
-            cur = llm_build_norm(ctx0, inpSA, hparams,
-                    model.layers[il].ffn_norm_exps, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm_exps", il);
-
-            cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, true,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    cb, il);
-            cb(cur, "ffn_moe_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_out);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_deepseek() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            if ((uint32_t) il < hparams.n_layer_dense_lead) {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            } else {
-                // MoE branch
-                ggml_tensor * moe_out =
-                        llm_build_moe_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_gate_inp,
-                            model.layers[il].ffn_up_exps,
-                            model.layers[il].ffn_gate_exps,
-                            model.layers[il].ffn_down_exps,
-                            nullptr,
-                            n_expert, n_expert_used,
-                            LLM_FFN_SILU, false,
-                            false, hparams.expert_weights_scale,
-                            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                            cb, il);
-                cb(moe_out, "ffn_moe_out", il);
-
-                // FFN shared expert
-                {
-                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_up_shexp,   NULL, NULL,
-                            model.layers[il].ffn_gate_shexp, NULL, NULL,
-                            model.layers[il].ffn_down_shexp, NULL, NULL,
-                            NULL,
-                            LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                    cb(ffn_shexp, "ffn_shexp", il);
-
-                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
-                    cb(cur, "ffn_out", il);
-                }
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_deepseek2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        bool is_lite = (hparams.n_layer == 27);
-
-        // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
-        // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-        const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
-        const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
-        const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
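-        // e.g. with purely illustrative values attn_factor = 1.0f, rope_yarn_log_mul = 0.1f,
-        // freq_scale = 0.25f and n_embd_head_k = 192 (not taken from any specific model):
-        //   logf(1.0f / freq_scale) = logf(4.0f)                  ~= 1.386f
-        //   mscale              ~= 1.0f + 0.1f * 1.386f           ~= 1.139f
-        //   kq_scale            ~= 1.139f * 1.139f / sqrtf(192.0f) ~= 0.094f
-        //   attn_factor_scaled  ~= 1.0f / 1.139f                  ~= 0.878f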
-
-        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
-        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-        const uint32_t kv_lora_rank = hparams.n_lora_kv;
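-        // for illustration only, with DeepSeek-V2-style shapes (n_rot = 64, n_embd_head_k = 192,
-        // n_lora_kv = 512): each head has 64 RoPE'd dims and 128 non-RoPE'd dims, and K/V are
-        // reconstructed from a 512-dim compressed latent per token (multi-head latent attention)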
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self_attention
-            {
-                struct ggml_tensor * q = NULL;
-                if (!is_lite) {
-                    // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
-                    q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
-                    cb(q, "q", il);
-
-                    q = llm_build_norm(ctx0, q, hparams,
-                            model.layers[il].attn_q_a_norm, NULL,
-                            LLM_NORM_RMS, cb, il);
-                    cb(q, "q", il);
-
-                    // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
-                    q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
-                    cb(q, "q", il);
-                } else {
-                    q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-                    cb(q, "q", il);
-                }
-
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(q->type, hparams.n_embd_head_k),
-                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
-                        0);
-                cb(q_nope, "q_nope", il);
-
-                // and {n_head * n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
-                        ggml_row_size(q->type, hparams.n_embd_head_k),
-                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
-                        ggml_row_size(q->type, n_embd_head_qk_nope));
-                cb(q_pe, "q_pe", il);
-
-                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
-                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
-
-                // split into {kv_lora_rank, n_tokens}
-                struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
-                        kv_pe_compresseed->nb[1],
-                        0);
-                cb(kv_compressed, "kv_compressed", il);
-
-                // and {n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
-                        kv_pe_compresseed->nb[1],
-                        kv_pe_compresseed->nb[1],
-                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
-                cb(k_pe, "k_pe", il);
-
-                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
-                kv_compressed = ggml_cont(ctx0, kv_compressed);
-                kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
-                        model.layers[il].attn_kv_a_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(kv_compressed, "kv_compressed", il);
-
-                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
-                cb(kv, "kv", il);
-
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        0);
-                cb(k_nope, "k_nope", il);
-
-                // and {n_head * n_embd_head_v, n_tokens}
-                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_cont(ctx0, v_states);
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                    0);
-                cb(v_states, "v_states", il);
-
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-                q_pe = ggml_rope_ext(
-                    ctx0, q_pe, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor_scaled, beta_fast, beta_slow
-                );
-                cb(q_pe, "q_pe", il);
-
-                // shared RoPE key
-                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-                k_pe = ggml_rope_ext(
-                    ctx0, k_pe, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor_scaled, beta_fast, beta_slow
-                );
-                cb(k_pe, "k_pe", il);
-
-                struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
-                cb(q_states, "q_states", il);
-
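-                // k_pe was built with a single head (the {n_embd_head_qk_rope, 1, n_tokens} view
-                // above), so ggml_repeat broadcasts it to q_pe's {n_embd_head_qk_rope, n_head, n_tokens}
-                // shape: every head shares the same rotary key component, concatenated with its
-                // own k_nope along dim 0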
-                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
-                cb(k_states, "k_states", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            if ((uint32_t) il < hparams.n_layer_dense_lead) {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            } else {
-                // MoE branch
-                ggml_tensor * moe_out =
-                        llm_build_moe_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_gate_inp,
-                            model.layers[il].ffn_up_exps,
-                            model.layers[il].ffn_gate_exps,
-                            model.layers[il].ffn_down_exps,
-                            model.layers[il].ffn_exp_probs_b,
-                            n_expert, n_expert_used,
-                            LLM_FFN_SILU, hparams.expert_weights_norm,
-                            true, hparams.expert_weights_scale,
-                            (enum llama_expert_gating_func_type) hparams.expert_gating_func,
-                            cb, il);
-                cb(moe_out, "ffn_moe_out", il);
-
-                // FFN shared expert
-                {
-                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_up_shexp,   NULL, NULL,
-                            model.layers[il].ffn_gate_shexp, NULL, NULL,
-                            model.layers[il].ffn_down_shexp, NULL, NULL,
-                            NULL,
-                            LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                    cb(ffn_shexp, "ffn_shexp", il);
-
-                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
-                    cb(cur, "ffn_out", il);
-                }
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_bitnet() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
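-                // BitNet-style layers keep a separate per-tensor scale alongside the low-bit
-                // weights, so each projection below is optionally followed by an elementwise
-                // ggml_mul with the corresponding w*_scale tensor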
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                if (model.layers[il].wq_scale) {
-                    Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
-                }
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                // B1.K
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                if (model.layers[il].wk_scale) {
-                    Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
-                }
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                // B1.V
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                if (model.layers[il].wv_scale) {
-                    Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
-                }
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        NULL, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].attn_sub_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_sub_norm", il);
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-                if (model.layers[il].wo_scale) {
-                    cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
-                }
-                if (model.layers[il].bo) {
-                    cur = ggml_add(ctx0, cur, model.layers[il].bo);
-                }
-                cb(cur, "attn_o_out", il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, model.layers[il].ffn_up_scale,
-                    model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
-                    NULL,                      NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_sub_out", il);
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                            model.layers[il].ffn_sub_norm, NULL,
-                            LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_sub_norm", il);
-
-            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur);
-            if (model.layers[il].ffn_down_scale) {
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
-            }
-            cb(cur, "ffn_down", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        // FIXME: do not use model.tok_embd directly, duplicate as model.output
-        cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-        return gf;
-    }
-
-    struct ggml_cgraph * build_t5_enc() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        GGML_ASSERT(lctx.is_encoding);
-        struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm_enc, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
-
-                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                cb(kq, "kq", il);
-
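-                // T5-style relative position bias: layers without their own bucket table fall back
-                // to the first layer's, and the gathered bias is added to the raw kq logits before
-                // the masked softmax (note the 1.0f scale: the logits are not divided by
-                // sqrtf(n_embd_head))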
-                struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
-                struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
-                struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
-                cb(kq_b, "kq_b", il);
-
-                kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
-                cb(kq, "kq_soft_max_ext", il);
-
-                struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
-                cb(v, "v", il);
-
-                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
-                cb(kqv, "kqv", il);
-
-                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                cb(kqv_merged, "kqv_merged", il);
-
-                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                cb(cur, "kqv_merged_cont", il);
-
-                ggml_build_forward_expand(gf, cur);
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur);
-                cb(cur, "kqv_out", il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm_enc, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                // T5 uses relu, flan-T5 uses gelu-gated
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up_enc,   NULL, NULL,
-                        model.layers[il].ffn_gate_enc, NULL, NULL,
-                        model.layers[il].ffn_down_enc, NULL, NULL,
-                        NULL,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
-                        cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = ggml_add(ctx0, cur, layer_dir);
-            }
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-        cb(cur, "result_embd", -1);
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm_enc, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_t5_dec() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        GGML_ASSERT(!lctx.is_encoding);
-        GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
-
-        struct ggml_tensor * embd_enc       = llm_build_inp_embd_enc();
-        struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
-
-        struct ggml_tensor * KQ_mask_dec   = build_inp_KQ_mask();
-        struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
-
-                struct ggml_tensor * k =
-                    ggml_view_3d(ctx0, kv_self.k_l[il],
-                            n_embd_head_k, n_kv, n_head_kv,
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                            0);
-                cb(k, "k", il);
-
-                struct ggml_tensor * v =
-                    ggml_view_3d(ctx0, kv_self.v_l[il],
-                            n_kv, n_embd_head_v, n_head_kv,
-                            ggml_element_size(kv_self.v_l[il])*n_ctx,
-                            ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                            0);
-                cb(v, "v", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-
-                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                cb(kq, "kq", il);
-
-                struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
-                struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
-                struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
-                cb(kq_b, "kq_b", il);
-
-                kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
-                cb(kq, "kq_soft_max_ext", il);
-
-                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-                cb(kqv, "kqv", il);
-
-                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                cb(kqv_merged, "kqv_merged", il);
-
-                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                cb(cur, "kqv_merged_cont", il);
-
-                ggml_build_forward_expand(gf, cur);
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-                cb(cur, "kqv_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, inpSA);
-            cb(cur, "cross_inp", il);
-
-            struct ggml_tensor * inpCA = cur;
-
-            // norm
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].attn_norm_cross, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm_cross", il);
-
-            // cross-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
-
-                struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
-
-                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                cb(kq, "kq", il);
-
-                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
-                cb(kq, "kq_soft_max_ext", il);
-
-                struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
-                cb(v, "v", il);
-
-                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
-                cb(kqv, "kqv", il);
-
-                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                cb(kqv_merged, "kqv_merged", il);
-
-                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                cb(cur, "kqv_merged_cont", il);
-
-                ggml_build_forward_expand(gf, cur);
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur);
-                cb(cur, "kqv_out", il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-                inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                // T5 uses relu, flan-T5 uses gelu-gated
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
-                        cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = ggml_add(ctx0, cur, layer_dir);
-            }
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-        cb(cur, "result_embd", -1);
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_jais() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
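-                // the fused QKV projection packs n_embd + 2*n_embd_gqa values per row; the three
-                // views below slice Q, K and V out of each row at element offsets 0, n_embd and
-                // n_embd + n_embd_gqa respectively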
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            inpL = ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_chatglm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = nullptr;
-                struct ggml_tensor * Kcur = nullptr;
-                struct ggml_tensor * Vcur = nullptr;
-                if (model.layers[il].wqkv == nullptr) {
-                    Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                    if (model.layers[il].bq) {
-                        Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    }
-                    Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                    if (model.layers[il].bk) {
-                        Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    }
-                    Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                    if (model.layers[il].bv) {
-                        Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    }
-                } else {
-                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                    cb(cur, "wqkv", il);
-                    if (model.layers[il].bqkv) {
-                        cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                        cb(cur, "bqkv", il);
-                    }
-                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-                }
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-                //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur_rope", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur_rope", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // Add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        NULL,                      NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            inpL = ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_nemotron() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        //GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm,
-                    model.layers[il].ffn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "ffn_norm", il);
-
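-            // squared-ReLU MLP: no gate projection (NULL gate arguments) with LLM_FFN_SEQ,
-            // i.e. roughly down(relu(up(x) + up_b)^2) + down_b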
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                    NULL,                      NULL,                        NULL,
-                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                    NULL,
-                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_exaone() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    ggml_cgraph * build_rwkv6() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // Token shift state dimensions should be 2 * n_embd
-        GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
-
-        const int64_t n_seqs = ubatch.n_seqs;
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-        const int64_t n_tokens = ubatch.n_tokens;
-        GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
-        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-        struct ggml_tensor * state_copy = build_inp_s_copy();
-        struct ggml_tensor * state_mask = build_inp_s_mask();
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            const llama_layer * layer = &model.layers[il];
-
-            // (ab)using the KV cache to store the states
-            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.k_l[il], state_copy, state_mask,
-                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
-            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.v_l[il], state_copy, state_mask,
-                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
-
-            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
-            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs);
-
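-            // token_shift stores 2*n_embd values per sequence (hence the n_embd_k_s()/2 assert
-            // above): the first n_embd column is the time-mix (attention) shift and the second
-            // the channel-mix (FFN) shift, split by the two views below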
-            struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
-            struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
-
-            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il);
-            struct ggml_tensor * x_prev = ggml_concat(
-                ctx0,
-                att_shift,
-                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
-                1
-            );
-
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
-            ggml_build_forward_expand(gf, cur);
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0,
-                    wkv_states,
-                    ggml_view_1d(
-                        ctx0,
-                        kv_self.v_l[il],
-                        hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
-                    )
-                )
-            );
-
-            struct ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il);
-            x_prev = ggml_concat(
-                ctx0,
-                ffn_shift,
-                ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
-                1
-            );
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
-            ggml_build_forward_expand(gf, cur);
-
-            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
-            struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn));
-
-            token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
-
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0,
-                    ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
-                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
-                )
-            );
-
-            if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
-                cur = ggml_scale(ctx0, cur, 0.5F);
-            }
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-
-        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
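
The token_shift state handled above carries, for each layer and sequence, the previous token's normalized activations: x_prev is the layer input shifted right by one position, with the stored state standing in for the token before the start of the ubatch, and the activations of the last position are written back for the next ubatch. A minimal sketch of that bookkeeping over plain float buffers for a single sequence (illustrative only, not the ggml implementation):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // x holds n_tokens * n_embd activations (row-major); shift holds the n_embd
    // activations carried over from the previous ubatch. Requires n_tokens >= 1.
    static std::vector<float> token_shift(const std::vector<float> & x,
                                          std::vector<float> & shift,
                                          std::size_t n_tokens, std::size_t n_embd) {
        std::vector<float> x_prev(n_tokens * n_embd);
        for (std::size_t t = 0; t < n_tokens; ++t) {
            const float * src = (t == 0) ? shift.data() : x.data() + (t - 1) * n_embd;
            std::copy(src, src + n_embd, x_prev.begin() + t * n_embd);
        }
        // the last token of this ubatch becomes the shift state for the next one
        std::copy(x.data() + (n_tokens - 1) * n_embd, x.data() + n_tokens * n_embd, shift.begin());
        return x_prev;
    }
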
-
-    // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
-    ggml_cgraph * build_rwkv6qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
-
-        const int64_t n_seqs = ubatch.n_seqs;
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-        const int64_t n_tokens = ubatch.n_tokens;
-        GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
-        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-        struct ggml_tensor * state_copy = build_inp_s_copy();
-        struct ggml_tensor * state_mask = build_inp_s_mask();
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        for (int il = 0; il < n_layer; ++il) {
-            const llama_layer * layer = &model.layers[il];
-
-            // (ab)using the KV cache to store the states
-            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.k_l[il], state_copy, state_mask,
-                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
-            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.v_l[il], state_copy, state_mask,
-                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
-
-            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
-            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);
-
-            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
-            struct ggml_tensor * x_prev = ggml_concat(
-                ctx0,
-                token_shift,
-                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
-                1
-            );
-
-            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0,
-                    ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0),
-                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
-                )
-            );
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
-            ggml_build_forward_expand(gf, ffn_inp);
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0,
-                    wkv_states,
-                    ggml_view_1d(
-                        ctx0,
-                        kv_self.v_l[il],
-                        hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
-                    )
-                )
-            );
-
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-
-        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
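
Both RWKV builders above "(ab)use" the KV cache as a per-layer recurrent state store: kv_self.k_l[il] holds the token-shift state and kv_self.v_l[il] the WKV state, gathered for the sequences of the current ubatch at the start of each layer and copied back at kv_head at the end. A rough sketch of that load/store pattern over a flat array of per-slot states (names and layout are illustrative):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    struct state_cache {
        std::size_t        state_size;  // floats per slot (e.g. n_embd_v_s)
        std::vector<float> data;        // n_slots * state_size

        state_cache(std::size_t n_slots, std::size_t state_size)
            : state_size(state_size), data(n_slots * state_size, 0.0f) {}

        // gather the states of the cache slots used by this ubatch
        std::vector<float> load(const std::vector<std::size_t> & slots) const {
            std::vector<float> out(slots.size() * state_size);
            for (std::size_t i = 0; i < slots.size(); ++i) {
                std::copy(data.begin() + slots[i] * state_size,
                          data.begin() + (slots[i] + 1) * state_size,
                          out.begin() + i * state_size);
            }
            return out;
        }

        // write the updated states back as a contiguous block starting at `head`
        void store(std::size_t head, const std::vector<float> & states) {
            std::copy(states.begin(), states.end(), data.begin() + head * state_size);
        }
    };
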
-
-    // ref: https://github.com/facebookresearch/chameleon
-    // based on the original build_llama() function, changes:
-    //   * qk-norm
-    //   * swin-norm
-    //   * removed bias
-    //   * removed MoE
-    struct ggml_cgraph * build_chameleon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            if (hparams.swin_norm) {
-                cur = inpL;
-            } else {
-                cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_norm", il);
-            }
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
-                                ggml_element_size(Qcur) * n_embd_head,
-                                ggml_element_size(Qcur) * n_embd_head * n_head,
-                                0);
-                    cb(Qcur, "Qcur", il);
-
-                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                                model.layers[il].attn_q_norm,
-                                model.layers[il].attn_q_norm_b,
-                                LLM_NORM, cb, il);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                if (model.layers[il].attn_k_norm) {
-                    Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
-                                ggml_element_size(Kcur) * n_embd_head,
-                                ggml_element_size(Kcur) * n_embd_head * n_head_kv,
-                                0);
-                    cb(Kcur, "Kcur", il);
-
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                               model.layers[il].attn_k_norm,
-                               model.layers[il].attn_k_norm_b,
-                               LLM_NORM, cb, il);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, nullptr,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-
-                if (hparams.swin_norm) {
-                    cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].attn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                }
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            if (!hparams.swin_norm) {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-            }
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            if (hparams.swin_norm) {
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output_with_img_logits", -1);
-
-        // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs.
-        // Needs to be removed once image outputs are supported.
-        int img_token_end_idx = 8196;
-        int img_token_start_idx = 4;
-        int num_img_tokens = img_token_end_idx - img_token_start_idx;
-        // create a 1d tensor of size num_img_tokens filled with -FLT_MAX,
-        // which ensures that text token logits are always larger than image token logits
-        struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens);
-        img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX);
-        cb(img_logits, "img_logits", -1);
-        cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
-        cb(cur, "result_output", -1);
-        ggml_build_forward_expand(gf, cur);
-        return gf;
-    }
-
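
The img_logits block in build_chameleon above pins the logits of the image-token id range to -FLT_MAX so that text tokens always win during sampling. A standalone sketch of the same idea over a plain logits buffer (the id range 4..8196 is taken from the code above):

    #include <algorithm>
    #include <cfloat>
    #include <vector>

    static void suppress_image_tokens(std::vector<float> & logits,
                                      int img_token_start_idx = 4,
                                      int img_token_end_idx   = 8196) {
        const int end = std::min<int>(img_token_end_idx, (int) logits.size());
        for (int i = img_token_start_idx; i < end; ++i) {
            logits[i] = -FLT_MAX;  // never chosen by argmax or softmax sampling
        }
    }
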
-    ggml_cgraph * build_solar() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        struct ggml_tensor * bskcn_1;
-        struct ggml_tensor * bskcn_2;
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            if (hparams.n_bskcn(0, il)) {
-                bskcn_1 = inpSA;
-            }
-
-            if (hparams.n_bskcn(1, il)) {
-                bskcn_2 = inpSA;
-            }
-
-            if (hparams.n_bskcn(2, il)) {
-                inpSA = ggml_add(
-                   ctx0,
-                   ggml_mul(ctx0, bskcn_1, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
-                   ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
-            }
-
-            if (hparams.n_bskcn(3, il)) {
-                inpSA = ggml_add(
-                   ctx0,
-                   ggml_mul(ctx0, bskcn_2, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
-                   ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
-            }
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-        ggml_build_forward_expand(gf, cur);
-        return gf;
-    }
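
The bskcn (backbone skip connection) handling in build_solar above saves the hidden state at designated layers and later blends it back in with a learned two-element gate: inpSA becomes bskcn_tv[0] * saved + bskcn_tv[1] * current. Element-wise sketch with scalar gates (the original broadcasts 1-element ggml views):

    #include <cstddef>
    #include <vector>

    // saved and current must have the same size; tv0/tv1 are the learned gates.
    static std::vector<float> bskcn_blend(const std::vector<float> & saved,
                                          const std::vector<float> & current,
                                          float tv0, float tv1) {
        std::vector<float> out(current.size());
        for (std::size_t i = 0; i < current.size(); ++i) {
            out[i] = tv0 * saved[i] + tv1 * current[i];
        }
        return out;
    }
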
-
-    struct ggml_cgraph * build_wavtokenizer_dec() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL));
-
-        cur = ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1);
-        cur = ggml_add(ctx0, cur, model.conv1d_b);
-
-        // posnet
-        for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) {
-            const auto & layer = model.layers[il].posnet;
-
-            inpL = cur;
-
-            switch (il) {
-                case 0:
-                case 1:
-                case 3:
-                case 4:
-                    {
-                        cur = llm_build_norm(ctx0, cur, hparams,
-                                layer.norm1,
-                                layer.norm1_b,
-                                LLM_NORM_GROUP, cb, 0);
-
-                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
-
-                        cur = ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1);
-                        cur = ggml_add(ctx0, cur, layer.conv1_b);
-
-                        cur = llm_build_norm(ctx0, cur, hparams,
-                                layer.norm2,
-                                layer.norm2_b,
-                                LLM_NORM_GROUP, cb, 0);
-
-                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
-
-                        cur = ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1);
-                        cur = ggml_add(ctx0, cur, layer.conv2_b);
-
-                        cur = ggml_add(ctx0, cur, inpL);
-                    } break;
-                case 2:
-                    {
-                        cur = llm_build_norm(ctx0, cur, hparams,
-                                layer.attn_norm,
-                                layer.attn_norm_b,
-                                LLM_NORM_GROUP, cb, 0);
-
-                        struct ggml_tensor * q;
-                        struct ggml_tensor * k;
-                        struct ggml_tensor * v;
-
-                        q = ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1);
-                        k = ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1);
-                        v = ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1);
-
-                        q = ggml_add(ctx0, q, layer.attn_q_b);
-                        k = ggml_add(ctx0, k, layer.attn_k_b);
-                        v = ggml_add(ctx0, v, layer.attn_v_b);
-
-                        q = ggml_cont(ctx0, ggml_transpose(ctx0, q));
-                        k = ggml_cont(ctx0, ggml_transpose(ctx0, k));
-
-                        struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-
-                        kq = ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f);
-
-                        cur = ggml_mul_mat(ctx0, kq, v);
-
-                        cur = ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1);
-                        cur = ggml_add(ctx0, cur, layer.attn_o_b);
-
-                        cur = ggml_add(ctx0, cur, inpL);
-                    } break;
-                case 5:
-                    {
-                        cur = llm_build_norm(ctx0, cur, hparams,
-                                layer.norm,
-                                layer.norm_b,
-                                LLM_NORM_GROUP, cb, 0);
-                    } break;
-                default: GGML_ABORT("unknown posnet layer");
-            };
-        }
-
-        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.tok_norm,
-                model.tok_norm_b,
-                LLM_NORM, cb, -1);
-
-        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-        inpL = cur;
-
-        // convnext
-        for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) {
-            const auto & layer = model.layers[il].convnext;
-
-            cur = inpL;
-
-            cur = ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1);
-            cur = ggml_add(ctx0, cur, layer.dw_b);
-
-            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    layer.norm,
-                    layer.norm_b,
-                    LLM_NORM, cb, -1);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    layer.pw1, layer.pw1_b, NULL,
-                    NULL,      NULL,        NULL,
-                    layer.pw2, layer.pw2_b, NULL,
-                    NULL,
-                    LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-
-            cur = ggml_mul(ctx0, cur, layer.gamma);
-
-            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-            inpL = ggml_add(ctx0, cur, inpL);
-        }
-
-        cur = inpL;
-
-        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        cur = ggml_add(ctx0, cur, model.output_b);
-        cb(cur, "result_embd", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
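
The posnet attention branch above (case 2) is plain single-head scaled dot-product attention over the time axis: queries, keys and values come from 1-d convolutions, the scores are scaled by 1/sqrt(n_embd) and softmaxed, and the output is the weighted sum of values. A naive reference over row-major (n_pos x d) buffers, for illustration only:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    static std::vector<float> attention(const std::vector<float> & q,
                                        const std::vector<float> & k,
                                        const std::vector<float> & v,
                                        std::size_t n_pos, std::size_t d) {
        std::vector<float> out(n_pos * d, 0.0f);
        const float scale = 1.0f / std::sqrt((float) d);
        for (std::size_t i = 0; i < n_pos; ++i) {
            // scores of position i against all positions, then softmax
            std::vector<float> w(n_pos);
            float max_w = -1e30f;
            for (std::size_t j = 0; j < n_pos; ++j) {
                float s = 0.0f;
                for (std::size_t c = 0; c < d; ++c) s += q[i*d + c] * k[j*d + c];
                w[j] = s * scale;
                if (w[j] > max_w) max_w = w[j];
            }
            float sum = 0.0f;
            for (std::size_t j = 0; j < n_pos; ++j) { w[j] = std::exp(w[j] - max_w); sum += w[j]; }
            for (std::size_t j = 0; j < n_pos; ++j) w[j] /= sum;
            // weighted sum of values
            for (std::size_t j = 0; j < n_pos; ++j)
                for (std::size_t c = 0; c < d; ++c)
                    out[i*d + c] += w[j] * v[j*d + c];
        }
        return out;
    }
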
-};
-
-static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & moves) {
-    llama_ubatch dummy = {};
-    dummy.equal_seqs = true;
-
-    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
-
-    struct llm_build_context llm(lctx, dummy, cb, false);
-
-    llm.init();
-
-    struct ggml_cgraph * result = llm.build_defrag(moves);
-
-    llm.free();
-
-    return result;
-}
-
-static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
-    llama_ubatch dummy = {};
-    dummy.equal_seqs = true;
-
-    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
-
-    struct llm_build_context llm(lctx, dummy, cb, false);
-
-    llm.init();
-
-    struct ggml_cgraph * result = llm.build_k_shift();
-
-    llm.free();
-
-    return result;
-}
-
-static struct ggml_cgraph * llama_build_graph(
-         llama_context & lctx,
-    const llama_ubatch & ubatch,
-                  bool   worst_case) {
-    const auto & model = lctx.model;
-
-    // this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
-    llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
-        if (il >= 0) {
-            ggml_format_name(cur, "%s-%d", name, il);
-        } else {
-            ggml_set_name(cur, name);
-        }
-
-        if (!lctx.cparams.offload_kqv) {
-            if (strcmp(name, "kqv_merged_cont") == 0) {
-                // all nodes between the KV store and the attention output are run on the CPU
-                ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, lctx.backend_cpu);
-            }
-        }
-
-        // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
-        // FIXME: fix in ggml_backend_sched
-        const bool full_offload = lctx.model.params.n_gpu_layers > (int) lctx.model.hparams.n_layer;
-        if (ubatch.n_tokens < 32 || full_offload) {
-            if (il != -1 && strcmp(name, "norm") == 0) {
-                const auto & dev_layer = lctx.model.dev_layer(il);
-                for (auto & backend : lctx.backends) {
-                    if (ggml_backend_get_device(backend.get()) == dev_layer) {
-                        if (ggml_backend_supports_op(backend.get(), cur)) {
-                            ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, backend.get());
-                        }
-                    }
-                }
-            }
-        }
-    };
-
-    struct ggml_cgraph * result = NULL;
-
-    struct llm_build_context llm(lctx, ubatch, cb, worst_case);
-
-    llm.init();
-
-    switch (model.arch) {
-        case LLM_ARCH_LLAMA:
-        case LLM_ARCH_MINICPM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
-            {
-                result = llm.build_llama();
-            } break;
-        case LLM_ARCH_MLLAMA:
-            {
-                result = llm.build_mllama();
-            } break;
-        case LLM_ARCH_DECI:
-            {
-                result = llm.build_deci();
-            } break;
-        case LLM_ARCH_BAICHUAN:
-            {
-                result = llm.build_baichuan();
-            } break;
-        case LLM_ARCH_FALCON:
-            {
-                result = llm.build_falcon();
-            } break;
-        case LLM_ARCH_GROK:
-            {
-                result = llm.build_grok();
-            } break;
-        case LLM_ARCH_STARCODER:
-            {
-                result = llm.build_starcoder();
-            } break;
-        case LLM_ARCH_REFACT:
-            {
-                result = llm.build_refact();
-            } break;
-        case LLM_ARCH_BERT:
-        case LLM_ARCH_JINA_BERT_V2:
-        case LLM_ARCH_NOMIC_BERT:
-            {
-                result = llm.build_bert();
-            } break;
-        case LLM_ARCH_BLOOM:
-            {
-                result = llm.build_bloom();
-            } break;
-        case LLM_ARCH_MPT:
-            {
-                result = llm.build_mpt();
-            } break;
-        case LLM_ARCH_STABLELM:
-            {
-                result = llm.build_stablelm();
-            } break;
-        case LLM_ARCH_QWEN:
-            {
-                result = llm.build_qwen();
-            } break;
-        case LLM_ARCH_QWEN2:
-            {
-                result = llm.build_qwen2();
-            } break;
-        case LLM_ARCH_QWEN2VL:
-            {
-                lctx.n_pos_per_token = 4;
-                result = llm.build_qwen2vl();
-            } break;
-        case LLM_ARCH_QWEN2MOE:
-            {
-                result = llm.build_qwen2moe();
-            } break;
-        case LLM_ARCH_PHI2:
-            {
-                result = llm.build_phi2();
-            } break;
-        case LLM_ARCH_PHI3:
-        case LLM_ARCH_PHIMOE:
-            {
-                result = llm.build_phi3();
-            } break;
-        case LLM_ARCH_PLAMO:
-            {
-                result = llm.build_plamo();
-            } break;
-        case LLM_ARCH_GPT2:
-            {
-                result = llm.build_gpt2();
-            } break;
-        case LLM_ARCH_CODESHELL:
-            {
-                result = llm.build_codeshell();
-            } break;
-        case LLM_ARCH_ORION:
-            {
-                result = llm.build_orion();
-            } break;
-        case LLM_ARCH_INTERNLM2:
-            {
-                result = llm.build_internlm2();
-            } break;
-        case LLM_ARCH_MINICPM3:
-            {
-                result = llm.build_minicpm3();
-            } break;
-        case LLM_ARCH_GEMMA:
-            {
-                result = llm.build_gemma();
-            } break;
-        case LLM_ARCH_GEMMA2:
-            {
-                result = llm.build_gemma2();
-            } break;
-        case LLM_ARCH_STARCODER2:
-            {
-                result = llm.build_starcoder2();
-            } break;
-        case LLM_ARCH_MAMBA:
-            {
-                result = llm.build_mamba();
-            } break;
-        case LLM_ARCH_XVERSE:
-            {
-                result = llm.build_xverse();
-            } break;
-        case LLM_ARCH_COMMAND_R:
-            {
-                result = llm.build_command_r();
-            } break;
-        case LLM_ARCH_COHERE2:
-            {
-                result = llm.build_cohere2();
-            } break;
-        case LLM_ARCH_DBRX:
-            {
-                result = llm.build_dbrx();
-            } break;
-        case LLM_ARCH_OLMO:
-            {
-                result = llm.build_olmo();
-            } break;
-        case LLM_ARCH_OLMO2:
-            {
-                result = llm.build_olmo2();
-            } break;
-        case LLM_ARCH_OLMOE:
-            {
-                result = llm.build_olmoe();
-            } break;
-        case LLM_ARCH_OPENELM:
-            {
-                result = llm.build_openelm();
-            } break;
-        case LLM_ARCH_GPTNEOX:
-            {
-                result = llm.build_gptneox();
-            } break;
-        case LLM_ARCH_ARCTIC:
-            {
-                result = llm.build_arctic();
-            } break;
-        case LLM_ARCH_DEEPSEEK:
-            {
-                result = llm.build_deepseek();
-            } break;
-        case LLM_ARCH_DEEPSEEK2:
-            {
-                result = llm.build_deepseek2();
-            } break;
-        case LLM_ARCH_CHATGLM:
-            {
-                result = llm.build_chatglm();
-            } break;
-        case LLM_ARCH_BITNET:
-            {
-                result = llm.build_bitnet();
-            } break;
-        case LLM_ARCH_T5:
-            {
-                if (lctx.is_encoding) {
-                    result = llm.build_t5_enc();
-                } else {
-                    result = llm.build_t5_dec();
-                }
-            } break;
-        case LLM_ARCH_T5ENCODER:
-            {
-                result = llm.build_t5_enc();
-            } break;
-        case LLM_ARCH_JAIS:
-            {
-                result = llm.build_jais();
-            } break;
-        case LLM_ARCH_NEMOTRON:
-            {
-                result = llm.build_nemotron();
-            } break;
-        case LLM_ARCH_EXAONE:
-            {
-                result = llm.build_exaone();
-            } break;
-        case LLM_ARCH_RWKV6:
-            {
-                result = llm.build_rwkv6();
-            } break;
-        case LLM_ARCH_RWKV6QWEN2:
-            {
-                result = llm.build_rwkv6qwen2();
-            } break;
-        case LLM_ARCH_CHAMELEON:
-            {
-                result = llm.build_chameleon();
-            } break;
-        case LLM_ARCH_SOLAR:
-            {
-                result = llm.build_solar();
-            } break;
-        case LLM_ARCH_WAVTOKENIZER_DEC:
-            {
-                result = llm.build_wavtokenizer_dec();
-            } break;
-        default:
-            GGML_ABORT("fatal error");
-    }
-
-    // add on pooling layer
-    if (lctx.cparams.embeddings) {
-        result = llm.append_pooling(result);
-    }
-
-    llm.free();
-
-    return result;
-}
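
llama_build_graph above is a single dispatch from the architecture enum to the matching build_*() method, followed by the optional pooling append. The same mapping idea in miniature, using a lookup table instead of a switch purely to make the shape of the dispatch explicit (enum values and the builder type are illustrative, not the real API):

    #include <functional>
    #include <map>
    #include <stdexcept>
    #include <string>

    enum class arch { llama, mamba, rwkv6 };

    using graph = std::string;  // stand-in for ggml_cgraph *

    static graph build_graph(arch a) {
        static const std::map<arch, std::function<graph()>> builders = {
            { arch::llama, []{ return graph("llama graph"); } },
            { arch::mamba, []{ return graph("mamba graph"); } },
            { arch::rwkv6, []{ return graph("rwkv6 graph"); } },
        };
        const auto it = builders.find(a);
        if (it == builders.end()) {
            throw std::runtime_error("unknown architecture"); // GGML_ABORT in the original
        }
        return it->second();
    }

The original keeps the plain switch, which is easy to step through and avoids any indirection per graph build.
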
-
-// returns the result of ggml_backend_sched_graph_compute_async execution
-static enum ggml_status llama_graph_compute(
-          llama_context & lctx,
-            ggml_cgraph * gf,
-                    int   n_threads,
-        ggml_threadpool * threadpool) {
-    if (lctx.backend_cpu != nullptr) {
-        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(lctx.backend_cpu));
-        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
-        set_threadpool_fn(lctx.backend_cpu, threadpool);
-    }
-
-    // set the number of threads for all the backends
-    for (const auto & set_n_threads_fn : lctx.set_n_threads_fns) {
-        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
-    }
-
-    auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
-    if (status != GGML_STATUS_SUCCESS) {
-        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
-    }
-
-    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
-
-    return status;
-}
-
-static int llama_prepare_sbatch(
-        llama_context     & lctx,
-        const llama_batch & batch,
-        uint32_t          & n_outputs) {
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    const uint32_t n_tokens_all = batch.n_tokens;
-    const  int64_t n_embd       = hparams.n_embd;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
-    lctx.n_queued_tokens += n_tokens_all;
-    lctx.embd_seq.clear();
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch.logits[i] != 0;
-        }
-    } else if (lctx.logits_all || embd_pooled) {
-        n_outputs = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs = 1;
-    }
-
-    lctx.sbatch.from_batch(batch, batch.n_embd,
-        /* simple_split */ !lctx.kv_self.recurrent,
-        /* logits_all   */ n_outputs == n_tokens_all);
-
-    // reserve output buffer
-    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
-        return -2;
-    };
-
-    return 0;
-}
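
The output counting above decides how many rows the logits/embedding buffers need: with per-token logit flags (and no pooled embeddings) only flagged tokens produce outputs, with logits_all or pooled embeddings every token does, and otherwise only the last token. Standalone sketch of that rule:

    #include <cstdint>
    #include <vector>

    static uint32_t count_outputs(const std::vector<int8_t> & logit_flags, // may be empty
                                  uint32_t n_tokens,
                                  bool logits_all,
                                  bool embd_pooled) {
        if (!logit_flags.empty() && !embd_pooled) {
            uint32_t n = 0;
            for (uint32_t i = 0; i < n_tokens; ++i) {
                n += logit_flags[i] != 0;
            }
            return n;
        }
        if (logits_all || embd_pooled) {
            return n_tokens;
        }
        return 1; // keep last output only
    }
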
-
-static int llama_prepare_ubatch(
-        llama_context          & lctx,
-        llama_kv_slot_restorer & kv_slot_restorer,
-        llama_ubatch           & ubatch,
-        const uint32_t           n_outputs,
-        const uint32_t           n_tokens_all) {
-    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
-
-    auto       & kv_self = lctx.kv_self;
-    const auto & cparams = lctx.cparams;
-    const auto & hparams = lctx.model.hparams;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-    if (lctx.kv_self.recurrent) {
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
-            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
-        } else {
-            // recurrent model architectures are easier to implement
-            // with equal-length sequences
-            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
-        }
-    } else {
-        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
-    }
-
-    // count the outputs in this u_batch
-    {
-        int32_t n_outputs_new = 0;
-
-        if (n_outputs == n_tokens_all) {
-            n_outputs_new = ubatch.n_tokens;
-        } else {
-            GGML_ASSERT(ubatch.output);
-            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
-                n_outputs_new += int32_t(ubatch.output[i] != 0);
-            }
-        }
-
-        // needs to happen before the graph is built
-        lctx.n_outputs = n_outputs_new;
-    }
-
-    // non-causal masks do not use the KV cache
-    if (hparams.causal_attn) {
-        llama_kv_cache_update(&lctx);
-
-        // if we have enough unused cells before the current head ->
-        //   better to start searching from the beginning of the cache, hoping to fill it
-        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
-            kv_self.head = 0;
-        }
-
-        auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-        if (!slot) {
-            llama_kv_cache_defrag(kv_self);
-            llama_kv_cache_update(&lctx);
-            slot = llama_kv_cache_find_slot(kv_self, ubatch);
-        }
-        if (!slot) {
-            return 1;
-        }
-        kv_slot_restorer.save(slot);
-
-        if (!kv_self.recurrent) {
-            // a heuristic, to avoid attending the full cache if it is not yet utilized
-            // after enough generations, the benefit from this heuristic disappears
-            // if we start defragmenting the cache, the benefit from this will be more important
-            const uint32_t pad = llama_kv_cache_get_padding(cparams);
-            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-            //kv_self.n = llama_kv_cache_cell_max(kv_self);
-        }
-    }
-
-    return 0;
-}
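
The heuristic at the end of llama_prepare_ubatch caps kv_self.n, the number of cells actually attended to, at the highest used cell index rounded up to the cache padding and clamped to the cache size, instead of always attending over the whole cache. The arithmetic in isolation (GGML_PAD rounds its first argument up to a multiple of the second):

    #include <algorithm>
    #include <cstdint>

    static uint32_t pad_up(uint32_t x, uint32_t n) {
        return ((x + n - 1) / n) * n;              // same effect as GGML_PAD
    }

    static uint32_t kv_window(uint32_t cell_max, uint32_t pad, uint32_t kv_size) {
        return std::min(kv_size, std::max(pad, pad_up(cell_max, pad)));
    }
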
-
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:      llama context
-//   - inp_batch: batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_impl(
-         llama_context & lctx,
-           llama_batch   inp_batch) {
-
-    lctx.is_encoding = false;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporarily allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-    const llama_batch & batch = batch_allocr.batch;
-
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = hparams.n_vocab;
-
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
-
-    {
-        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
-        if (ret != 0) {
-            return ret;
-        }
-    }
-
-    while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        {
-            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
-            if (ret != 0) {
-                return ret;
-            }
-        }
-
-        const int         n_threads  = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool   : lctx.threadpool_batch;
-
-        GGML_ASSERT(n_threads > 0);
-
-        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
-
-        ggml_backend_sched_reset(lctx.sched.get());
-        ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
-
-        // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = ggml_graph_node(gf, -1);
-        struct ggml_tensor * embd = ggml_graph_node(gf, -2);
-
-        if (lctx.n_outputs == 0) {
-            // no output
-            res  = nullptr;
-            embd = nullptr;
-        } else if (cparams.embeddings) {
-            embd = nullptr;
-            for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
-                if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-                    embd = ggml_graph_node(gf, i);
-                    break;
-                }
-            }
-        } else {
-            embd = nullptr; // do not extract embeddings when not needed
-            GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
-        }
-
-        if (!cparams.causal_attn) {
-            res = nullptr; // do not extract logits when not needed
-        }
-
-        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
-
-        ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
-
-        llama_set_inputs(lctx, ubatch);
-
-        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
-        if (compute_status != GGML_STATUS_SUCCESS) {
-            kv_slot_restorer.restore(kv_self);
-            switch (compute_status) {
-                case GGML_STATUS_ABORTED:
-                    return 2;
-                case GGML_STATUS_ALLOC_FAILED:
-                    return -2;
-                case GGML_STATUS_FAILED:
-                default:
-                    return -3;
-            }
-        }
-
-        // update the kv ring buffer
-        {
-            kv_self.head += ubatch.n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self.head >= kv_self.size) {
-                kv_self.head = 0;
-            }
-        }
-
-        // plot the computation graph in dot format (for debugging purposes)
-        //if (n_past%100 == 0) {
-        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
-        //}
-
-        // extract logits
-        if (res) {
-            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
-            float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
-            const int32_t n_outputs_new = lctx.n_outputs;
-
-            if (n_outputs_new) {
-                GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
-                GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
-                ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
-            }
-        }
-
-        // extract embeddings
-        if (embd) {
-            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), embd);
-            GGML_ASSERT(backend_embd != nullptr);
-
-            switch (cparams.pooling_type) {
-                case LLAMA_POOLING_TYPE_NONE:
-                    {
-                        // extract token embeddings
-                        GGML_ASSERT(lctx.embd != nullptr);
-                        float * embd_out = lctx.embd + n_outputs_prev*n_embd;
-                        const int32_t n_outputs_new = lctx.n_outputs;
-
-                        if (n_outputs_new) {
-                            GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
-                            GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_MEAN:
-                case LLAMA_POOLING_TYPE_CLS:
-                case LLAMA_POOLING_TYPE_LAST:
-                    {
-                        // extract sequence embeddings (cleared before processing each batch)
-                        auto & embd_seq_out = lctx.embd_seq;
-
-                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
-                            embd_seq_out[seq_id].resize(n_embd);
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_RANK:
-                    {
-                        // extract the rerank score - a single float per sequence
-                        auto & embd_seq_out = lctx.embd_seq;
-
-                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
-                            embd_seq_out[seq_id].resize(1);
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                    {
-                        GGML_ABORT("unknown pooling type");
-                    }
-            }
-        }
-        n_outputs_prev += lctx.n_outputs;
-    }
-
-    // set output mappings
-    {
-        bool sorted_output = true;
-
-        GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
-
-        for (size_t i = 0; i < n_outputs; ++i) {
-            size_t out_id = lctx.sbatch.out_ids[i];
-            lctx.output_ids[out_id] = i;
-            if (out_id != i) {
-                sorted_output = false;
-            }
-        }
-
-        if (sorted_output) {
-            lctx.sbatch.out_ids.clear();
-        }
-    }
-
-    // set to total number of outputs in the batch, for use in llama_get_logits_ith
-    lctx.n_outputs = n_outputs;
-
-    // wait for the computation to finish (automatically done when obtaining the model output)
-    //llama_synchronize(&lctx);
-
-    // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
-        // - do not defrag small contexts (i.e. < 2048 tokens)
-        // - count the padding towards the number of used tokens
-        const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + llama_kv_cache_get_padding(cparams))/float(kv_self.n)) : 0.0f;
-
-        // queue defragmentation for next llama_kv_cache_update
-        if (fragmentation > cparams.defrag_thold) {
-            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-            llama_kv_cache_defrag(kv_self);
-        }
-    }
-
-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(lctx.sched.get());
-
-    return 0;
-}
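
The contract documented above for llama_decode_impl is: 0 on success, a positive value on a recoverable warning (1: no KV cache slot found, 2: computation aborted), and a negative value on error (-1: invalid input, -2: allocation/reserve failure, -3: compute failure). A hypothetical caller-side sketch of handling it (not part of the llama.cpp API):

    #include <cstdio>

    // stand-in for the real decode entry point
    static int decode_once(/* llama_context &, llama_batch */) { return 0; }

    static bool try_decode() {
        const int ret = decode_once();
        if (ret == 0) {
            return true;                   // decoded, logits/embeddings are ready
        }
        if (ret > 0) {
            std::fprintf(stderr, "decode warning %d (e.g. KV cache full)\n", ret);
            return false;                  // caller may free KV cells or shrink the batch and retry
        }
        std::fprintf(stderr, "decode error %d\n", ret);
        return false;                      // unrecoverable for this batch
    }
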
-
-// encode a batch of tokens by evaluating the encoder part of the transformer
-//
-//   - lctx:      llama context
-//   - batch:     batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_encode_impl(
-         llama_context & lctx,
-           llama_batch   inp_batch) {
-
-    lctx.is_encoding = true;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporarily allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens = batch.n_tokens;
-
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-
-    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
-    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
-
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
-
-    lctx.n_queued_tokens += n_tokens;
-
-    const int64_t n_embd = hparams.n_embd;
-
-    lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
-
-    const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
-
-    // reserve output buffer
-    if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
-        return -2;
-    };
-
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        lctx.output_ids[i] = i;
-    }
-
-    lctx.inp_embd_enc = NULL;
-    lctx.n_outputs = n_tokens;
-
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-    ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
-
-    GGML_ASSERT(n_threads > 0);
-
-    ggml_backend_sched_reset(lctx.sched.get());
-    ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-
-    ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
-
-    // the output embeddings after the final encoder normalization
-    struct ggml_tensor * embd = nullptr;
-
-    // there are two cases here
-    if (llama_model_has_decoder(&lctx.model)) {
-        // first case is an encoder-decoder T5 model where embeddings are passed to decoder
-        embd = ggml_graph_node(gf, -1);
-        GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor");
-    } else {
-        // second case is an encoder-only T5 model
-        if (cparams.embeddings) {
-            // only output embeddings if required
-            embd = ggml_graph_node(gf, -1);
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = ggml_graph_node(gf, -2);
-            }
-            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
-        }
-    }
-
-    ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
-
-    llama_set_inputs(lctx, ubatch);
-
-    const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
-    switch (compute_status) {
-        case GGML_STATUS_SUCCESS:
-            break;
-        case GGML_STATUS_ABORTED:
-            return 2;
-        case GGML_STATUS_ALLOC_FAILED:
-            return -2;
-        case GGML_STATUS_FAILED:
-        default:
-            return -3;
-    }
-
-    // extract embeddings
-    if (embd) {
-        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), embd);
-        GGML_ASSERT(backend_embd != nullptr);
-
-        if (llama_model_has_decoder(&lctx.model)) {
-            lctx.embd_enc.resize(n_tokens*n_embd);
-            float * embd_out = lctx.embd_enc.data();
-
-            ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
-            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
-
-            // remember the sequence ids used during the encoding - needed for cross attention later
-            lctx.seq_ids_enc.resize(n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
-                    llama_seq_id seq_id = ubatch.seq_id[i][s];
-                    lctx.seq_ids_enc[i].insert(seq_id);
-                }
-            }
-        } else {
-            GGML_ASSERT(lctx.embd != nullptr);
-
-            switch (cparams.pooling_type) {
-                case LLAMA_POOLING_TYPE_NONE:
-                    {
-                        // extract token embeddings
-                        GGML_ASSERT(lctx.embd != nullptr);
-                        float * embd_out = lctx.embd;
-
-                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
-                        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
-                    } break;
-                case LLAMA_POOLING_TYPE_MEAN:
-                case LLAMA_POOLING_TYPE_CLS:
-                case LLAMA_POOLING_TYPE_LAST:
-                    {
-                        // extract sequence embeddings
-                        auto & embd_seq_out = lctx.embd_seq;
-                        embd_seq_out.clear();
-
-                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
-
-                        for (uint32_t i = 0; i < n_tokens; i++) {
-                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
-                            embd_seq_out[seq_id].resize(n_embd);
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_RANK:
-                    {
-                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
-                        //       wait for an encoder model that requires this pooling type in order to test it
-                        //       https://github.com/ggerganov/llama.cpp/pull/9510
-                        GGML_ABORT("RANK pooling not implemented yet");
-                    }
-                case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                    {
-                        GGML_ABORT("unknown pooling type");
-                    }
-            }
-        }
-    }
-
-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(lctx.sched.get());
-
-    return 0;
-}
-
-// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
-static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
-    auto & kv_self = lctx.kv_self;
-
-    const auto & hparams = lctx.model.hparams;
-
-    const uint32_t n_layer = hparams.n_layer;
-
-    const uint32_t n_kv   = llama_kv_cache_cell_max(kv_self);
-    const uint32_t n_used = kv_self.used;
-
-    assert(n_used <= n_kv);
-
-    //const int64_t t_start = ggml_time_us();
-
-    // groups of cells moved
-    std::vector moves;
-
-    // each move requires 6*n_layer tensors (see build_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = model.max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (lctx.model.max_nodes() - 2*n_layer)/(6*n_layer);
-
-    // determine which KV cells to move where
-    //
-    //  cell i moves to ids[i]
-    //
-    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
-    //
-    std::vector<uint32_t> ids(n_kv, n_kv);
-
-    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        const auto & cell0 = kv_self.cells[i0];
-
-        if (!cell0.is_empty()) {
-            ids[i0] = i0;
-
-            continue;
-        }
-
-        // found a hole - fill it with data from the end of the cache
-
-        uint32_t nh = 1;
-
-        // determine the size of the hole
-        while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
-            nh++;
-        }
-
-        uint32_t nf = 0;
-        uint32_t is = n_kv - 1;
-
-        // starting from the end, find nh non-empty cells
-        for (; is > i0; --is) {
-            const auto & cell1 = kv_self.cells[is];
-
-            if (cell1.is_empty() || ids[is] != n_kv) {
-                continue;
-            }
-
-            // non-empty cell which is not yet moved
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        // this can only happen if `n_used` is not accurate, which would be a bug
-        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
-
-        nf = 0;
-
-        uint32_t i1 = is;
-
-        // are we moving a continuous block of memory?
-        bool cont = false;
-
-        // go back and move the nf cells to the hole
-        for (; i1 < n_kv; ++i1) {
-            auto & cell1 = kv_self.cells[i1];
-
-            if (cell1.is_empty() || ids[i1] != n_kv) {
-                cont = false;
-                continue;
-            }
-
-            // this cell goes to (i0 + nf)
-            ids[i1] = i0 + nf;
-
-            // move the cell meta data
-            kv_self.cells[i0 + nf] = cell1;
-
-            // clear the old cell and move the head there
-            cell1 = llama_kv_cell();
-            kv_self.head = n_used;
-
-            if (!cont) {
-                moves.push_back({i1, i0 + nf, 1});
-                cont = true;
-            } else {
-                moves.back().len++;
-            }
-
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
-
-        i0 += nh - 1;
-    }
-
-    if (moves.size() == 0) {
-        return;
-    }
-
-    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n",  moves.size());
-
-#if 0
-    // CPU defrag
-    //
-    // TODO: optimizations are possible:
-    //       - multiple threads
-    //       - avoid copying to the host memory when already there
-    //
-    // likely not worth the effort, as we have ggml_graph based defrag
-    //
-
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
-    const uint32_t kv_size = kv_self.size;
-
-    std::vector<uint8_t> buf_k;
-    std::vector<uint8_t> buf_v;
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        const size_t k_size     = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_size);
-
-        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-        const size_t v_size    = ggml_row_size (kv_self.v_l[il]->type, n_embd_v_gqa*kv_size);
-
-        buf_k.resize(k_size);
-        buf_v.resize(v_size);
-
-        ggml_backend_tensor_get(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_get(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
-
-        // batch move [i, i+nm) to [id, id+nm)
-        // note: cells can move only to a lower index
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == n_kv) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < n_kv && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
-            // move keys
-            {
-                const int64_t os =  i*k_size_row;
-                const int64_t od = id*k_size_row;
-
-                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
-            }
-
-            // move values (note: they are transposed)
-            {
-                const int64_t os =  i;
-                const int64_t od = id;
-
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
-                }
-            }
-
-            i += nm - 1;
-        }
-
-        ggml_backend_tensor_set(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_set(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
-    }
-#else
-    // ggml_graph defrag
-
-    for (std::size_t i = 0; i < moves.size(); i += max_moves) {
-        std::vector chunk;
-        auto end = std::min(i + max_moves, moves.size());
-        chunk.assign(moves.begin() + i, moves.begin() + end);
-
-        ggml_backend_sched_reset(lctx.sched.get());
-
-        //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
-        ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
-
-        llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
-    }
-#endif
-
-    //const int64_t t_end = ggml_time_us();
-
-    //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
-}
-
-static void llama_kv_cache_update_impl(struct llama_context & lctx) {
-    bool need_reserve = false;
-
-    if (lctx.kv_self.has_shift) {
-        if (!llama_kv_cache_can_shift(&lctx)) {
-            GGML_ABORT("The current context does not support K-shift");
-        }
-
-        // apply K-shift if needed
-        if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(lctx.sched.get());
-
-            ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
-
-            ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
-
-            llama_set_k_shift(lctx);
-
-            llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
-
-            need_reserve = true;
-        }
-
-        {
-            auto & kv_self = lctx.kv_self;
-
-            kv_self.has_shift = false;
-
-            for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = 0;
-            }
-        }
-    }
-
-    // defragment the KV cache if needed
-    if (lctx.kv_self.do_defrag) {
-        llama_kv_cache_defrag_impl(lctx);
-
-        need_reserve = true;
-
-        lctx.kv_self.do_defrag = false;
-    }
-
-    // reserve a worst case graph again
-    if (need_reserve) {
-        // TODO: extract to a function
-        // build worst-case graph
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
-        llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
-
-        // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(lctx.sched.get());
-        if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-        }
-    }
-}
-
-int32_t llama_set_adapter_lora(
-            struct llama_context * ctx,
-            struct llama_adapter_lora * adapter,
-            float scale) {
-    ctx->lora[adapter] = scale;
-    return 0;
-}
-
-int32_t llama_rm_adapter_lora(
-            struct llama_context * ctx,
-            struct llama_adapter_lora * adapter) {
-    auto pos = ctx->lora.find(adapter);
-    if (pos != ctx->lora.end()) {
-        ctx->lora.erase(pos);
-        return 0;
-    }
-
-    return -1;
-}
-
-void llama_clear_adapter_lora(struct llama_context * ctx) {
-    ctx->lora.clear();
-}
-
-int32_t llama_apply_adapter_cvec(
-        struct llama_context * ctx,
-                 const float * data,
-                      size_t   len,
-                     int32_t   n_embd,
-                     int32_t   il_start,
-                     int32_t   il_end) {
-    return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
-}
-
 //
 // interface implementation
 //
 
-struct llama_context_params llama_context_default_params() {
-    struct llama_context_params result = {
-        /*.n_ctx                       =*/ 512,
-        /*.n_batch                     =*/ 2048,
-        /*.n_ubatch                    =*/ 512,
-        /*.n_seq_max                   =*/ 1,
-        /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
-        /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
-        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
-        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
-        /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
-        /*.rope_freq_base              =*/ 0.0f,
-        /*.rope_freq_scale             =*/ 0.0f,
-        /*.yarn_ext_factor             =*/ -1.0f,
-        /*.yarn_attn_factor            =*/ 1.0f,
-        /*.yarn_beta_fast              =*/ 32.0f,
-        /*.yarn_beta_slow              =*/ 1.0f,
-        /*.yarn_orig_ctx               =*/ 0,
-        /*.defrag_thold                =*/ -1.0f,
-        /*.cb_eval                     =*/ nullptr,
-        /*.cb_eval_user_data           =*/ nullptr,
-        /*.type_k                      =*/ GGML_TYPE_F16,
-        /*.type_v                      =*/ GGML_TYPE_F16,
-        /*.logits_all                  =*/ false,
-        /*.embeddings                  =*/ false,
-        /*.offload_kqv                 =*/ true,
-        /*.flash_attn                  =*/ false,
-        /*.no_perf                     =*/ true,
-        /*.cross_attn                  =*/ false,
-        /*.abort_callback              =*/ nullptr,
-        /*.abort_callback_data         =*/ nullptr,
-    };
-
-    return result;
-}
-
 struct llama_sampler_chain_params llama_sampler_chain_default_params() {
     struct llama_sampler_chain_params result = {
         /*.no_perf                     =*/ true,
@@ -9823,6 +82,57 @@ int64_t llama_time_us(void) {
     return ggml_time_us();
 }
 
+// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
+static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
+    // loading time will be recalculated after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    model.t_load_us = 0;
+    time_meas tm(model.t_load_us);
+
+    model.t_start_us = tm.t_start_us;
+
+    try {
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides);
+
+        ml.print_info();
+
+        model.hparams.vocab_only = params.vocab_only;
+
+        try {
+            model.load_arch(ml);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
+        }
+        try {
+            model.load_hparams(ml);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
+        }
+        try {
+            model.load_vocab(ml);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
+        }
+
+        model.load_stats(ml);
+        model.print_info();
+
+        if (params.vocab_only) {
+            LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
+            return 0;
+        }
+
+        if (!model.load_tensors(ml)) {
+            return -2;
+        }
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
+        return -1;
+    }
+
+    return 0;
+}
+
 static struct llama_model * llama_model_load_from_file_impl(
         const std::string & path_model,
         std::vector<std::string> & splits,
@@ -9943,460 +253,6 @@ struct llama_model * llama_model_load_from_splits(
     return llama_model_load_from_file_impl(splits.front(), splits, params);
 }
 
-struct llama_context * llama_init_from_model(
-                 struct llama_model * model,
-        struct llama_context_params   params) {
-
-    if (!model) {
-        LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
-        return nullptr;
-    }
-
-    if (params.n_batch == 0 && params.n_ubatch == 0) {
-        LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
-        return nullptr;
-    }
-
-    if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
-        LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
-        return nullptr;
-    }
-
-    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
-    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
-    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
-        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
-        return nullptr;
-    }
-
-    llama_context * ctx = new llama_context(*model);
-
-    const auto & hparams = model->hparams;
-    auto       & cparams = ctx->cparams;
-
-    cparams.n_seq_max        = std::max(1u, params.n_seq_max);
-    cparams.n_threads        = params.n_threads;
-    cparams.n_threads_batch  = params.n_threads_batch;
-    cparams.yarn_ext_factor  = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast   = params.yarn_beta_fast;
-    cparams.yarn_beta_slow   = params.yarn_beta_slow;
-    cparams.defrag_thold     = params.defrag_thold;
-    cparams.embeddings       = params.embeddings;
-    cparams.offload_kqv      = params.offload_kqv;
-    cparams.flash_attn       = params.flash_attn;
-    cparams.no_perf          = params.no_perf;
-    cparams.pooling_type     = params.pooling_type;
-
-    cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
-    cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
-    cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
-
-    // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx            = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
-
-    // with causal attention, the batch size is limited by the context size
-    cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
-
-    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
-    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
-        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
-        cparams.n_batch = GGML_KQ_MASK_PAD;
-    }
-
-    cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
-
-    cparams.n_ctx_orig_yarn  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
-                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
-                                                              hparams.n_ctx_train;
-
-    cparams.cb_eval           = params.cb_eval;
-    cparams.cb_eval_user_data = params.cb_eval_user_data;
-
-    auto rope_scaling_type = params.rope_scaling_type;
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
-        rope_scaling_type = hparams.rope_scaling_type_train;
-    }
-
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
-        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
-    }
-
-    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
-        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
-    }
-
-    cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
-        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
-            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
-        } else {
-            cparams.pooling_type = hparams.pooling_type;
-        }
-    }
-
-    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
-        cparams.causal_attn = hparams.causal_attn;
-    } else {
-        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
-    }
-
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
-
-    LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
-    LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
-    LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
-    LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
-    LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
-    LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
-
-    if (n_ctx_per_seq < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
-    }
-
-    if (n_ctx_per_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
-    }
-
-    ctx->logits_all = params.logits_all;
-
-    // build worst-case graph for encoder if a model contains encoder
-    ctx->is_encoding = llama_model_has_encoder(model);
-
-    uint32_t kv_size = cparams.n_ctx;
-    ggml_type type_k = params.type_k;
-    ggml_type type_v = params.type_v;
-
-    // Mamba only needs a constant number of KV cache cells per sequence
-    if (llama_model_is_recurrent(model)) {
-        // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
-        // it's probably best to keep as much precision as possible for the states
-        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-    }
-
-    GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
-    GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
-
-    if (!hparams.vocab_only) {
-        // GPU backends
-        for (auto * dev : model->devices) {
-            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.emplace_back(backend);
-        }
-
-        // add ACCEL backends (such as BLAS)
-        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
-                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.emplace_back(backend);
-            }
-        }
-
-        // add CPU backend
-        ctx->backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
-        if (ctx->backend_cpu == nullptr) {
-            LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-        ctx->backends.emplace_back(ctx->backend_cpu);
-
-        // create a list of the set_n_threads functions in the backends
-        for (auto & backend : ctx->backends) {
-            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
-            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
-            if (reg) {
-                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
-                if (ggml_backend_set_n_threads_fn) {
-                    ctx->set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
-                }
-            }
-        }
-
-        llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
-
-        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, ctx->cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
-            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-
-        {
-            size_t memory_size_k = 0;
-            size_t memory_size_v = 0;
-
-            for (auto & k : ctx->kv_self.k_l) {
-                memory_size_k += ggml_nbytes(k);
-            }
-
-            for (auto & v : ctx->kv_self.v_l) {
-                memory_size_v += ggml_nbytes(v);
-            }
-
-            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                      (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-        }
-
-        // graph outputs buffer
-        {
-            // resized during inference when a batch uses more outputs
-            if (llama_output_reserve(*ctx, params.n_seq_max) < params.n_seq_max) {
-                LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-
-            LLAMA_LOG_INFO("%s: %10s  output buffer size = %8.2f MiB\n", __func__,
-                    ggml_backend_buffer_name(ctx->buf_output.get()),
-                    ggml_backend_buffer_get_size(ctx->buf_output.get()) / 1024.0 / 1024.0);
-        }
-
-        // scheduler and compute buffers
-        {
-            // buffer types used for the compute buffer of each backend
-            std::vector<ggml_backend_buffer_type_t> backend_buft;
-            std::vector<ggml_backend_t> backend_ptrs;
-            for (auto & backend : ctx->backends) {
-                auto * buft = ggml_backend_get_default_buffer_type(backend.get());
-                auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
-                if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model->devices.empty()) {
-                    // use the host buffer of the first device CPU for faster transfer of the intermediate state
-                    auto * dev = model->devices[0];
-                    auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
-                    if (host_buft) {
-                        buft = host_buft;
-                    }
-                }
-                backend_buft.push_back(buft);
-                backend_ptrs.push_back(backend.get());
-            }
-
-            const size_t max_nodes = model->max_nodes();
-
-            // buffer used to store the computation graph and the tensor meta data
-            ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
-
-            // TODO: move these checks to ggml_backend_sched
-            // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
-            bool pipeline_parallel =
-                model->n_devices() > 1 &&
-                model->params.n_gpu_layers > (int)model->hparams.n_layer &&
-                model->params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
-                params.offload_kqv;
-
-            // pipeline parallelism requires support for async compute and events in all devices
-            if (pipeline_parallel) {
-                for (auto & backend : ctx->backends) {
-                    auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
-                    if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
-                        // ignore CPU backend
-                        continue;
-                    }
-                    auto * dev = ggml_backend_get_device(backend.get());
-                    ggml_backend_dev_props props;
-                    ggml_backend_dev_get_props(dev, &props);
-                    if (!props.caps.async || !props.caps.events) {
-                        // device does not support async compute or events
-                        pipeline_parallel = false;
-                        break;
-                    }
-                }
-            }
-
-            ctx->sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
-
-            if (pipeline_parallel) {
-                LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched.get()));
-            }
-
-            // initialize scheduler with the worst-case graph
-            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-            llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-
-            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
-
-            // reserve pp graph first so that buffers are only allocated once
-            ggml_backend_sched_reserve(ctx->sched.get(), gf_pp);
-            int n_splits_pp = ggml_backend_sched_get_n_splits(ctx->sched.get());
-            int n_nodes_pp = ggml_graph_n_nodes(gf_pp);
-
-            // reserve with tg graph to get the number of splits and nodes
-            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            ggml_cgraph * gf_tg = llama_build_graph(*ctx, ubatch_tg, true);
-            ggml_backend_sched_reserve(ctx->sched.get(), gf_tg);
-            int n_splits_tg = ggml_backend_sched_get_n_splits(ctx->sched.get());
-            int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
-
-            // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-            gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
-            if (!ggml_backend_sched_reserve(ctx->sched.get(), gf_pp)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-
-            for (size_t i = 0; i < backend_ptrs.size(); ++i) {
-                ggml_backend_t backend = backend_ptrs[i];
-                ggml_backend_buffer_type_t buft = backend_buft[i];
-                size_t size = ggml_backend_sched_get_buffer_size(ctx->sched.get(), backend);
-                if (size > 1) {
-                    LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
-                            ggml_backend_buft_name(buft),
-                            size / 1024.0 / 1024.0);
-                }
-            }
-
-            if (n_nodes_pp == n_nodes_tg) {
-                LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, n_nodes_pp);
-            } else {
-                LLAMA_LOG_INFO("%s: graph nodes  = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
-            }
-            if (n_splits_pp == n_splits_tg) {
-                LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
-            } else {
-                LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
-            }
-        }
-    }
-
-    return ctx;
-}
-
-struct llama_context * llama_new_context_with_model(
-                 struct llama_model * model,
-        struct llama_context_params   params) {
-    return llama_init_from_model(model, params);
-}
-
-//
-// kv cache
-//
-
-// TODO: tmp bridges below until `struct llama_kv_cache` is exposed through the public API
-
-struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
-    return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
-}
-
-void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
-    llama_kv_cache_view_update(view, ctx->kv_self);
-}
-
-int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    return llama_get_kv_cache_token_count(ctx->kv_self);
-}
-
-int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
-    return llama_get_kv_cache_used_cells(ctx->kv_self);
-}
-
-void llama_kv_cache_clear(struct llama_context * ctx) {
-    llama_kv_cache_clear(ctx->kv_self);
-}
-
-bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
-}
-
-void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    if (seq_id_src == seq_id_dst) {
-        return;
-    }
-    llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
-}
-
-void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
-}
-
-void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    if (delta == 0) {
-        return;
-    }
-
-    llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta);
-}
-
-void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    if (d == 1) {
-        return;
-    }
-
-    llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
-}
-
-llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
-}
-
-void llama_kv_cache_defrag(struct llama_context * ctx) {
-    llama_kv_cache_defrag(ctx->kv_self);
-}
-
-void llama_kv_cache_update(struct llama_context * ctx) {
-    llama_kv_cache_update_impl(*ctx);
-}
-
-bool llama_kv_cache_can_shift(struct llama_context * ctx) {
-    return llama_kv_cache_can_shift(ctx->kv_self);
-}
-
-///
-
-int32_t llama_encode(
-        struct llama_context * ctx,
-          struct llama_batch   batch) {
-    const int ret = llama_encode_impl(*ctx, batch);
-    if (ret != 0) {
-        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
-    }
-
-    return ret;
-}
-
-int32_t llama_decode(
-        struct llama_context * ctx,
-          struct llama_batch   batch) {
-    const int ret = llama_decode_impl(*ctx, batch);
-    if (ret != 0) {
-        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
-    }
-
-    return ret;
-}
-
 //
 // chat templates
 //
@@ -10464,7 +320,6 @@ const char * llama_print_system_info(void) {
     static std::string s;
     s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls.
 
-
     for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
         auto * reg = ggml_backend_reg_get(i);
         auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
@@ -10483,43 +338,3 @@ const char * llama_print_system_info(void) {
 
     return s.c_str();
 }
-
-//
-// perf
-//
-
-struct llama_perf_context_data llama_perf_context(const struct llama_context * ctx) {
-    struct llama_perf_context_data data = {};
-
-    if (ctx == nullptr) {
-        return data;
-    }
-
-    data.t_start_ms  = 1e-3 * ctx->t_start_us;
-    data.t_load_ms   = 1e-3 * ctx->t_load_us;
-    data.t_p_eval_ms = 1e-3 * ctx->t_p_eval_us;
-    data.t_eval_ms   = 1e-3 * ctx->t_eval_us;
-    data.n_p_eval    = std::max(1, ctx->n_p_eval);
-    data.n_eval      = std::max(1, ctx->n_eval);
-
-    return data;
-}
-
-void llama_perf_context_print(const struct llama_context * ctx) {
-    const auto data = llama_perf_context(ctx);
-
-    const double t_end_ms = 1e-3 * ggml_time_us();
-
-    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
-    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
-    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
-    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
-}
-
-void llama_perf_context_reset(struct llama_context * ctx) {
-    ctx->t_start_us  = ggml_time_us();
-    ctx->t_eval_us   = ctx->n_eval = 0;
-    ctx->t_p_eval_us = ctx->n_p_eval = 0;
-}
diff --git a/llama/llama.cpp/src/unicode.cpp b/llama/llama.cpp/src/unicode.cpp
index 9dd53b9ad..73cb2b1a2 100644
--- a/llama/llama.cpp/src/unicode.cpp
+++ b/llama/llama.cpp/src/unicode.cpp
@@ -220,7 +220,6 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
     free(wbuf);
     return ret;
 #else
-
 #if defined(__clang__)
     // disable C++17 deprecation warning for std::codecvt_utf8
 #    pragma clang diagnostic push
diff --git a/llama/llama.go b/llama/llama.go
index e8cdafe7f..1c2329ce4 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -147,27 +147,27 @@ func (c *Context) Model() *Model {
 }
 
 func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
-	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
+	C.llama_kv_self_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
 }
 
 func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
-	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
+	return bool(C.llama_kv_self_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
 }
 
 func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
-	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
+	C.llama_kv_self_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
 }
 
 func (c *Context) KvCacheClear() {
-	C.llama_kv_cache_clear(c.c)
+	C.llama_kv_self_clear(c.c)
 }
 
 func (c *Context) KvCacheDefrag() {
-	C.llama_kv_cache_defrag(c.c)
+	C.llama_kv_self_defrag(c.c)
 }
 
 func (c *Context) KvCacheCanShift() bool {
-	return bool(C.llama_kv_cache_can_shift(c.c))
+	return bool(C.llama_kv_self_can_shift(c.c))
 }
 
 // Get the embeddings for a sequence id
diff --git a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch
index bef91881b..690f2e33c 100644
--- a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch
+++ b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch
@@ -24,10 +24,10 @@ problem.
  9 files changed, 21 insertions(+), 2 deletions(-)
 
 diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
-index dba7be33..65e150d6 100644
+index 273075f4..dd11f304 100644
 --- a/ggml/src/ggml-backend.cpp
 +++ b/ggml/src/ggml-backend.cpp
-@@ -106,7 +106,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
+@@ -107,7 +107,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
      if (buffer->iface.free_buffer != NULL) {
          buffer->iface.free_buffer(buffer);
      }
@@ -35,7 +35,7 @@ index dba7be33..65e150d6 100644
  }
  
  size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
-@@ -542,6 +541,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer)
+@@ -544,6 +543,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer)
  
      free(ctx->buffers);
      free(ctx);
@@ -43,7 +43,7 @@ index dba7be33..65e150d6 100644
  }
  
  static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-@@ -1865,6 +1865,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+@@ -1867,6 +1867,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
  
  static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_aligned_free(buffer->context, buffer->size);
@@ -55,7 +55,7 @@ index dba7be33..65e150d6 100644
  }
  
  static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
-@@ -1912,7 +1917,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
+@@ -1914,7 +1919,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
  };
  
  static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
@@ -65,7 +65,7 @@ index dba7be33..65e150d6 100644
      /* .init_tensor     = */ NULL, // no initialization required
      /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,
 diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
-index d410c024..a207ab1e 100644
+index cec36b36..4b057973 100644
 --- a/ggml/src/ggml-cann/ggml-cann.cpp
 +++ b/ggml/src/ggml-cann/ggml-cann.cpp
 @@ -530,6 +530,7 @@ static void ggml_backend_cann_buffer_free_buffer(
@@ -76,7 +76,7 @@ index d410c024..a207ab1e 100644
  }
  
  /**
-@@ -1198,6 +1199,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
+@@ -1199,6 +1200,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
   */
  static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
      ACL_CHECK(aclrtFreeHost(buffer->context));
@@ -85,10 +85,10 @@ index d410c024..a207ab1e 100644
  
  /**
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index ebb2ccae..dfff21a2 100644
+index fafe9633..59a49560 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -529,6 +529,7 @@ struct ggml_backend_cuda_buffer_context {
+@@ -533,6 +533,7 @@ struct ggml_backend_cuda_buffer_context {
  static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
      delete ctx;
@@ -96,7 +96,7 @@ index ebb2ccae..dfff21a2 100644
  }
  
  static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
-@@ -783,6 +784,7 @@ struct ggml_backend_cuda_split_buffer_context {
+@@ -788,6 +789,7 @@ struct ggml_backend_cuda_split_buffer_context {
  static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
      delete ctx;
@@ -104,7 +104,7 @@ index ebb2ccae..dfff21a2 100644
  }
  
  static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
-@@ -1055,6 +1057,7 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
+@@ -1061,6 +1063,7 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
  
  static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      CUDA_CHECK(cudaFreeHost(buffer->context));
@@ -125,10 +125,10 @@ index 50579227..2799a0a5 100644
  
  static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index c550142a..fd9a4e77 100644
+index 9f1c6c6c..310afe8a 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
-@@ -4350,6 +4350,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
+@@ -4641,6 +4641,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
      }
  
      free(ctx);
@@ -137,10 +137,10 @@ index c550142a..fd9a4e77 100644
  
  static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
-index f5906246..062e93b8 100644
+index b8b5cbd3..14d4561b 100644
 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp
 +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
-@@ -1203,6 +1203,7 @@ static void * const cl_ptr_base = (void *)(uintptr_t) 0x1000;
+@@ -1443,6 +1443,7 @@ struct ggml_backend_opencl_buffer_context {
  static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
      delete ctx;
@@ -149,10 +149,10 @@ index f5906246..062e93b8 100644
  
  static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
-index 97873acc..893ee0b9 100644
+index 862b9b66..34536681 100644
 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp
 +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
-@@ -419,6 +419,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+@@ -443,6 +443,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
      GGML_ASSERT(status);
      delete ctx;
@@ -161,10 +161,10 @@ index 97873acc..893ee0b9 100644
  
  static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
-index 792e0569..5e233e8b 100644
+index 3e48a924..a3d182fc 100644
 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp
 +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
-@@ -311,6 +311,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
+@@ -316,6 +316,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
      ggml_sycl_set_device(ctx->device);
  
      delete ctx;
@@ -172,7 +172,7 @@ index 792e0569..5e233e8b 100644
  }
  catch (sycl::exception const &exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-@@ -720,6 +721,7 @@ struct ggml_backend_sycl_split_buffer_context {
+@@ -761,6 +762,7 @@ struct ggml_backend_sycl_split_buffer_context {
  static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
      delete ctx;
@@ -180,7 +180,7 @@ index 792e0569..5e233e8b 100644
  }
  
  static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
-@@ -1053,6 +1055,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
+@@ -1095,6 +1097,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
  
  static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_sycl_host_free(buffer->context);
@@ -189,10 +189,10 @@ index 792e0569..5e233e8b 100644
  
  static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
 diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-index abe3e790..1dad714b 100644
+index 783a0ff8..8ac1e07e 100644
 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
 +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-@@ -7914,6 +7914,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+@@ -8639,6 +8639,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
      ggml_vk_destroy_buffer(ctx->dev_buffer);
      delete ctx;
@@ -200,7 +200,7 @@ index abe3e790..1dad714b 100644
  }
  
  static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
-@@ -8056,6 +8057,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
+@@ -8782,6 +8783,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
  static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
      ggml_vk_host_free(vk_instance.devices[0], buffer->context);
diff --git a/llama/patches/0002-pretokenizer.patch b/llama/patches/0002-pretokenizer.patch
index 93a5ce594..02790a51e 100644
--- a/llama/patches/0002-pretokenizer.patch
+++ b/llama/patches/0002-pretokenizer.patch
@@ -3,15 +3,17 @@ From: Michael Yang 
 Date: Mon, 16 Sep 2024 15:53:13 -0700
 Subject: [PATCH] pretokenizer
 
+allow for an unset pretokenizer with a warning in the
+logs instead of throwing an error
 ---
  src/llama-vocab.cpp | 14 +++-----------
  1 file changed, 3 insertions(+), 11 deletions(-)
 
 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index ad9ffe66..a4eee9b8 100644
+index 464ff01e..0125ee53 100644
 --- a/src/llama-vocab.cpp
 +++ b/src/llama-vocab.cpp
-@@ -1468,16 +1468,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+@@ -1491,16 +1491,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
          if (type == LLAMA_VOCAB_TYPE_BPE) {
              add_space_prefix = false;
              clean_spaces = true;
@@ -29,9 +31,9 @@ index ad9ffe66..a4eee9b8 100644
                  pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              } else if (
                      tokenizer_pre == "llama3"   ||
-@@ -1593,7 +1584,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
-                 tokenizer_pre == "megrez") {
-                 pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+@@ -1634,7 +1625,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+                 pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
+                 clean_spaces = false;
              } else {
 -                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
 +                LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
diff --git a/llama/patches/0003-embeddings.patch b/llama/patches/0003-embeddings.patch
index 176cb41da..e8253e8e2 100644
--- a/llama/patches/0003-embeddings.patch
+++ b/llama/patches/0003-embeddings.patch
@@ -1,52 +1,43 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang 
-Date: Mon, 16 Sep 2024 15:53:14 -0700
+From: jmorganca 
+Date: Tue, 8 Apr 2025 15:28:34 -0700
 Subject: [PATCH] embeddings
 
+allow a loaded model in llama.cpp to be used for
+both embeddings and causal attention text generation
+instead of forcing one or the other
 ---
- src/llama-context.cpp | 2 +-
- src/llama.cpp         | 6 ++++--
- 2 files changed, 5 insertions(+), 3 deletions(-)
+ src/llama-context.cpp | 6 +++---
+ 1 file changed, 3 insertions(+), 3 deletions(-)
 
 diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index 671d2a81..47e79ed4 100644
+index 4735e98e..65135172 100644
 --- a/src/llama-context.cpp
 +++ b/src/llama-context.cpp
-@@ -479,7 +479,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
+@@ -1232,7 +1232,7 @@ int llama_context::decode(llama_batch & inp_batch) {
+     int64_t n_outputs_all = 0;
+ 
+     // count outputs
+-    if (batch.logits && !embd_pooled) {
++    if (batch.logits) {
+         for (uint32_t i = 0; i < n_tokens_all; ++i) {
+             n_outputs_all += batch.logits[i] != 0;
+         }
+@@ -1344,7 +1344,7 @@ int llama_context::decode(llama_batch & inp_batch) {
+         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
+         //}
+ 
+-        auto * t_logits = cparams.embeddings ? nullptr         : res->get_logits();
++        auto * t_logits = cparams.causal_attn ? res->get_logits() : nullptr;
+         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
+ 
+         if (t_embd && res->get_embd_pooled()) {
+@@ -1488,7 +1488,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
      const auto n_embd  = hparams.n_embd;
  
      // TODO: use a per-batch flag for logits presence instead
--    const bool has_logits = !cparams.embeddings;
-+    const bool has_logits =  cparams.causal_attn;
-     const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+-    bool has_logits = !cparams.embeddings;
++    bool has_logits =  cparams.causal_attn;
+     bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
  
-     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-diff --git a/src/llama.cpp b/src/llama.cpp
-index 607f2786..ac85bfed 100644
---- a/src/llama.cpp
-+++ b/src/llama.cpp
-@@ -8652,7 +8652,6 @@ static int llama_decode_impl(
-             res  = nullptr;
-             embd = nullptr;
-         } else if (cparams.embeddings) {
--            res  = nullptr; // do not extract logits for embedding case
-             embd = nullptr;
-             for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
-                 if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-@@ -8660,12 +8659,15 @@ static int llama_decode_impl(
-                     break;
-                 }
-             }
--            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
-         } else {
-             embd = nullptr; // do not extract embeddings when not needed
-             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
-         }
- 
-+        if (!cparams.causal_attn) {
-+            res = nullptr; // do not extract logits when not needed
-+        }
-+
-         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
- 
-         ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
+     // TODO: hacky enc-dec support
diff --git a/llama/patches/0004-clip-unicode.patch b/llama/patches/0004-clip-unicode.patch
index 50a12aadd..adf30d0f5 100644
--- a/llama/patches/0004-clip-unicode.patch
+++ b/llama/patches/0004-clip-unicode.patch
@@ -1,19 +1,21 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang 
-Date: Mon, 16 Sep 2024 15:53:15 -0700
+From: jmorganca 
+Date: Tue, 8 Apr 2025 15:34:37 -0700
 Subject: [PATCH] clip-unicode
 
+fixes loading vision models in llama.cpp on windows
+filesystems for paths that include wide characters
 ---
- examples/llava/clip.cpp | 40 +++++++++++++++++++++++++++++++++++++++-
- 1 file changed, 39 insertions(+), 1 deletion(-)
+ examples/llava/clip.cpp | 39 +++++++++++++++++++++++++++++++++++++++
+ 1 file changed, 39 insertions(+)
 
 diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
-index 76d4a785..205af1eb 100644
+index 49c90b75..4b72ea9f 100644
 --- a/examples/llava/clip.cpp
 +++ b/examples/llava/clip.cpp
-@@ -58,6 +58,19 @@
- #   define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
- #endif // defined(LLAVA_LOG_OFF)
+@@ -28,6 +28,19 @@
+ #include 
+ #include 
  
 +#if defined(_WIN32)
 +#define WIN32_LEAN_AND_MEAN
@@ -28,49 +30,48 @@ index 76d4a785..205af1eb 100644
 +#endif
 +#endif
 +
+ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
+ 
  //#define CLIP_DEBUG_FUNCTIONS
+@@ -1429,7 +1442,29 @@ struct clip_model_loader {
+         {
+             std::vector<uint8_t> read_buf;
  
- // RGB uint8 image
-@@ -1402,8 +1415,29 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
-             gguf_free(ctx);
-             return nullptr;
-         }
--
 +#ifdef _WIN32
-+        int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0);
-+        if (!wlen) {
-+            return NULL;
-+        }
-+        wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
-+        wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wbuf, wlen);
-+        if (!wlen) {
-+            free(wbuf);
-+            return NULL;
-+        }
++            int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0);
++            if (!wlen) {
++                throw std::runtime_error(string_format("%s: failed to convert filename to wide string\n", __func__));
++            }
++            wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
++            wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wbuf, wlen);
++            if (!wlen) {
++                free(wbuf);
++                throw std::runtime_error(string_format("%s: failed to convert filename to wide string\n", __func__));
++            }
 +#if __GLIBCXX__
-+        int fd = _wopen(wbuf, _O_RDONLY | _O_BINARY);
-+        __gnu_cxx::stdio_filebuf buffer(fd, std::ios_base::in);
-+        std::istream fin(&buffer);
++            int fd = _wopen(wbuf, _O_RDONLY | _O_BINARY);
++            __gnu_cxx::stdio_filebuf buffer(fd, std::ios_base::in);
++            std::istream fin(&buffer);
 +#else // MSVC
-+        // unused in our current build
-+        auto fin = std::ifstream(wbuf, std::ios::binary);
++            // unused in our current build
++            auto fin = std::ifstream(wbuf, std::ios::binary);
 +#endif
-+        free(wbuf);
++            free(wbuf);
 +#else
-         auto fin = std::ifstream(fname, std::ios::binary);
+             auto fin = std::ifstream(fname, std::ios::binary);
 +#endif
-         if (!fin) {
-             LOG_ERR("cannot open model file for loading tensors\n");
-             clip_free(new_clip);
-@@ -1443,7 +1477,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
-                 ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+             if (!fin) {
+                 throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
+             }
+@@ -1456,7 +1491,11 @@ struct clip_model_loader {
+                     ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+                 }
              }
-         }
 +#if defined(_WIN32) && defined(__GLIBCXX__)
-+        close(fd);
++            close(fd);
 +#else
-         fin.close();
+             fin.close();
 +#endif
-     }
  
-     // vision model
+             LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
+         }
diff --git a/llama/patches/0005-solar-pro.patch b/llama/patches/0005-solar-pro.patch
index e201682bf..770464d15 100644
--- a/llama/patches/0005-solar-pro.patch
+++ b/llama/patches/0005-solar-pro.patch
@@ -1,47 +1,40 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang 
-Date: Mon, 16 Sep 2024 15:53:16 -0700
+From: jmorganca 
+Date: Tue, 8 Apr 2025 16:03:51 -0700
 Subject: [PATCH] solar-pro
 
-solar-pro introduces block skip connections where blocks are connected
-to other, non-sequential blocks with a scale multiple
-
-this change adds 4 new keys to store the skip connections and one new
-tensor to store the scalar. the scalar is implemented a 1-dimensional
-tensor with 2 elements dervied from the model's bskcn_tv configuration.
-in general, the values are (bskcn_tv, 1 - bskcn_tv)
+adds support for the Solar Pro architecture
 ---
- src/llama-arch.cpp         |  21 +++++
+ src/llama-arch.cpp         |  21 ++++
  src/llama-arch.h           |   3 +
  src/llama-hparams.cpp      |   8 ++
- src/llama-hparams.h        |   5 ++
+ src/llama-hparams.h        |   5 +
  src/llama-model-loader.cpp |   1 +
- src/llama-model.cpp        |  44 +++++++++++
+ src/llama-model.cpp        | 207 +++++++++++++++++++++++++++++++++++++
  src/llama-model.h          |   3 +
- src/llama.cpp              | 152 ++++++++++++++++++++++++++++++++++++-
- 8 files changed, 236 insertions(+), 1 deletion(-)
+ 7 files changed, 248 insertions(+)
 
 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index 97a1e7e5..a1e0ebcc 100644
+index a6fddc7f..0b0fedcd 100644
 --- a/src/llama-arch.cpp
 +++ b/src/llama-arch.cpp
-@@ -61,6 +61,7 @@ static const std::map LLM_ARCH_NAMES = {
+@@ -68,6 +68,7 @@ static const std::map LLM_ARCH_NAMES = {
      { LLM_ARCH_GRANITE,          "granite"          },
      { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
      { LLM_ARCH_CHAMELEON,        "chameleon"        },
 +    { LLM_ARCH_SOLAR,            "solar"            },
      { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
-     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
- };
-@@ -125,6 +126,7 @@ static const std::map LLM_KV_NAMES = {
-     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
-     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
-     { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
-+    { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,  "%s.attention.block_skip_connection"  },
+     { LLM_ARCH_PLM,              "plm"              },
+     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
+@@ -140,6 +141,7 @@ static const std::map LLM_KV_NAMES = {
+     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,       "%s.attention.relative_buckets_count"       },
+     { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
+     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
++    { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,        "%s.attention.block_skip_connection"        },
  
      { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
      { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
-@@ -1271,6 +1273,24 @@ static const std::map> LLM_TENSOR_N
+@@ -1478,6 +1480,24 @@ static const std::map> LLM_TENSOR_N
              { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
          },
      },
@@ -66,7 +59,7 @@ index 97a1e7e5..a1e0ebcc 100644
      {
          LLM_ARCH_WAVTOKENIZER_DEC,
          {
-@@ -1429,6 +1449,7 @@ static const std::map LLM_TENSOR_INFOS = {
+@@ -1671,6 +1691,7 @@ static const std::map LLM_TENSOR_INFOS = {
      {LLM_TENSOR_FFN_EXP_PROBS_B,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
      // this tensor is loaded for T5, but never used
      {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
@@ -75,18 +68,18 @@ index 97a1e7e5..a1e0ebcc 100644
      {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
      {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 diff --git a/src/llama-arch.h b/src/llama-arch.h
-index 122fdceb..77919578 100644
+index 2c2099b3..74aa3dd0 100644
 --- a/src/llama-arch.h
 +++ b/src/llama-arch.h
-@@ -65,6 +65,7 @@ enum llm_arch {
+@@ -72,6 +72,7 @@ enum llm_arch {
      LLM_ARCH_GRANITE,
      LLM_ARCH_GRANITE_MOE,
      LLM_ARCH_CHAMELEON,
 +    LLM_ARCH_SOLAR,
      LLM_ARCH_WAVTOKENIZER_DEC,
-     LLM_ARCH_UNKNOWN,
- };
-@@ -129,6 +130,7 @@ enum llm_kv {
+     LLM_ARCH_PLM,
+     LLM_ARCH_BAILINGMOE,
+@@ -144,6 +145,7 @@ enum llm_kv {
      LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
      LLM_KV_ATTENTION_SLIDING_WINDOW,
      LLM_KV_ATTENTION_SCALE,
@@ -94,7 +87,7 @@ index 122fdceb..77919578 100644
  
      LLM_KV_ROPE_DIMENSION_COUNT,
      LLM_KV_ROPE_DIMENSION_SECTIONS,
-@@ -311,6 +313,7 @@ enum llm_tensor {
+@@ -340,6 +342,7 @@ enum llm_tensor {
      LLM_TENSOR_ENC_OUTPUT_NORM,
      LLM_TENSOR_CLS,
      LLM_TENSOR_CLS_OUT,
@@ -103,14 +96,13 @@ index 122fdceb..77919578 100644
      LLM_TENSOR_CONVNEXT_DW,
      LLM_TENSOR_CONVNEXT_NORM,
 diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
-index ea87b295..f3955de9 100644
+index 90dfe7a7..8a667960 100644
 --- a/src/llama-hparams.cpp
 +++ b/src/llama-hparams.cpp
-@@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
-     // corresponds to Mamba's ssm_states size
+@@ -70,6 +70,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
      return ssm_d_state * ssm_d_inner;
  }
-+
+ 
 +bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const {
 +    if (il < n_layer) {
 +        return n_bskcn_arr[n][il] > 0;
@@ -118,12 +110,15 @@ index ea87b295..f3955de9 100644
 +
 +    GGML_ABORT("fatal error");
 +}
-\ No newline at end of file
++
+ bool llama_hparams::is_swa(uint32_t il) const {
+     if (il < n_layer) {
+         return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
 diff --git a/src/llama-hparams.h b/src/llama-hparams.h
-index 1fe45410..1bdcdfd5 100644
+index 4e0b5719..c3147cbc 100644
 --- a/src/llama-hparams.h
 +++ b/src/llama-hparams.h
-@@ -50,6 +50,8 @@ struct llama_hparams {
+@@ -51,6 +51,8 @@ struct llama_hparams {
      std::array n_head_kv_arr;
      std::array n_ff_arr;
  
@@ -132,18 +127,18 @@ index 1fe45410..1bdcdfd5 100644
      uint32_t n_layer_dense_lead = 0;
      uint32_t n_lora_q           = 0;
      uint32_t n_lora_kv          = 0;
-@@ -133,6 +135,9 @@ struct llama_hparams {
- 
+@@ -149,6 +151,9 @@ struct llama_hparams {
      // dimension of the recurrent state embeddings
      uint32_t n_embd_v_s() const;
-+
+ 
 +    // Block skip connection
 +    bool n_bskcn(uint32_t n, uint32_t il) const;
++
+     bool is_swa(uint32_t il) const;
  };
  
- static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable");
 diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
-index 05d58ad9..1252aca1 100644
+index ea73a8a7..a012aeae 100644
 --- a/src/llama-model-loader.cpp
 +++ b/src/llama-model-loader.cpp
 @@ -439,6 +439,7 @@ namespace GGUFMeta {
@@ -155,10 +150,10 @@ index 05d58ad9..1252aca1 100644
  llama_model_loader::llama_model_loader(
          const std::string & fname,
 diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index 36a0a009..ad1315c6 100644
+index b74dd72c..5fbd0055 100644
 --- a/src/llama-model.cpp
 +++ b/src/llama-model.cpp
-@@ -1238,6 +1238,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+@@ -1372,6 +1372,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                      default: type = LLM_TYPE_UNKNOWN;
                 }
              } break;
@@ -180,7 +175,7 @@ index 36a0a009..ad1315c6 100644
          case LLM_ARCH_WAVTOKENIZER_DEC:
              {
                  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-@@ -3316,6 +3331,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+@@ -3701,6 +3716,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  
                          layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
  
@@ -215,54 +210,12 @@ index 36a0a009..ad1315c6 100644
                          layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                          layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                          layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-@@ -3900,6 +3943,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
-         case LLM_ARCH_GRANITE:
-         case LLM_ARCH_GRANITE_MOE:
-         case LLM_ARCH_CHAMELEON:
-+        case LLM_ARCH_SOLAR:
-             return LLAMA_ROPE_TYPE_NORM;
+@@ -12244,6 +12287,165 @@ struct llm_build_chameleon : public llm_graph_context {
+     }
+ };
  
-         // the pairs of head values are offset by n_rot/2
-diff --git a/src/llama-model.h b/src/llama-model.h
-index a7c30444..1afb0024 100644
---- a/src/llama-model.h
-+++ b/src/llama-model.h
-@@ -55,6 +55,7 @@ enum llm_type {
-     LLM_TYPE_15B,
-     LLM_TYPE_16B,
-     LLM_TYPE_20B,
-+    LLM_TYPE_22B,
-     LLM_TYPE_30B,
-     LLM_TYPE_32B,
-     LLM_TYPE_34B,
-@@ -281,6 +282,8 @@ struct llama_layer {
-     struct ggml_tensor * ffn_up_scale   = nullptr;
-     struct ggml_tensor * ffn_down_scale = nullptr;
- 
-+    struct ggml_tensor * bskcn_tv = nullptr;
-+
-     struct llama_layer_posnet posnet;
- 
-     struct llama_layer_convnext convnext;
-diff --git a/src/llama.cpp b/src/llama.cpp
-index ac85bfed..6d320ea4 100644
---- a/src/llama.cpp
-+++ b/src/llama.cpp
-@@ -7953,9 +7953,155 @@ struct llm_build_context {
-         cb(img_logits, "img_logits", -1);
-         cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
-         cb(cur, "result_output", -1);
--
-         ggml_build_forward_expand(gf, cur);
-+        return gf;
-+   }
-+
-+   ggml_cgraph * build_solar() {
-+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-+
-+        // mutable variable, needed during the last layer of the computation to skip unused tokens
-+        int32_t n_tokens = this->n_tokens;
-+
++struct llm_build_solar : public llm_graph_context {
++    llm_build_solar(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
 +        const int64_t n_embd_head = hparams.n_embd_head_v;
 +        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 +        GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -270,13 +223,15 @@ index ac85bfed..6d320ea4 100644
 +        struct ggml_tensor * cur;
 +        struct ggml_tensor * inpL;
 +
-+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
++        inpL = build_inp_embd(model.tok_embd);
 +
 +        // inp_pos - contains the positions
 +        struct ggml_tensor * inp_pos = build_inp_pos();
 +
 +        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
++        auto * inp_attn = build_attn_inp_kv_unified();
++
++        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 +
 +        struct ggml_tensor * bskcn_1;
 +        struct ggml_tensor * bskcn_2;
@@ -305,88 +260,94 @@ index ac85bfed..6d320ea4 100644
 +                   ggml_mul(ctx0, bskcn_2, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
 +                   ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
 +            }
- 
++
 +            // norm
-+            cur = llm_build_norm(ctx0, inpL, hparams,
++            cur = build_norm(inpL,
 +                    model.layers[il].attn_norm, NULL,
-+                    LLM_NORM_RMS, cb, il);
++                    LLM_NORM_RMS, il);
 +            cb(cur, "attn_norm", il);
 +
 +            // self-attention
 +            {
 +                // rope freq factors for llama3; may return nullptr for llama2 and other models
-+                struct ggml_tensor * rope_factors = build_rope_factors(il);
++                ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 +
 +                // compute Q and K and RoPE them
-+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
++                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 +                cb(Qcur, "Qcur", il);
 +                if (model.layers[il].bq) {
 +                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
 +                    cb(Qcur, "Qcur", il);
 +                }
 +
-+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
++                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
 +                cb(Kcur, "Kcur", il);
 +                if (model.layers[il].bk) {
 +                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
 +                    cb(Kcur, "Kcur", il);
 +                }
 +
-+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
++                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
 +                cb(Vcur, "Vcur", il);
 +                if (model.layers[il].bv) {
 +                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
 +                    cb(Vcur, "Vcur", il);
 +                }
 +
++                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
++                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
++                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
++
 +                Qcur = ggml_rope_ext(
-+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-+                    ext_factor, attn_factor, beta_fast, beta_slow
-+                );
-+                cb(Qcur, "Qcur", il);
++                        ctx0, Qcur, inp_pos, rope_factors,
++                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
++                        ext_factor, attn_factor, beta_fast, beta_slow
++                        );
 +
 +                Kcur = ggml_rope_ext(
-+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-+                    ext_factor, attn_factor, beta_fast, beta_slow
-+                );
++                        ctx0, Kcur, inp_pos, rope_factors,
++                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
++                        ext_factor, attn_factor, beta_fast, beta_slow
++                        );
++                
++                cb(Qcur, "Qcur", il);
 +                cb(Kcur, "Kcur", il);
++                cb(Vcur, "Vcur", il);
 +
-+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
++                cur = build_attn(inp_attn, gf,
 +                        model.layers[il].wo, model.layers[il].bo,
-+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
++                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
++                cb(cur, "attn_out", il);
 +            }
 +
 +            if (il == n_layer - 1) {
 +                // skip computing output for unused tokens
-+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-+                n_tokens = n_outputs;
++                ggml_tensor * inp_out_ids = build_inp_out_ids();
 +                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
 +                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
 +            }
 +
-+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
++            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
 +            cb(ffn_inp, "ffn_inp", il);
 +
 +            // feed-forward network
-+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
++            cur = build_norm(ffn_inp,
 +                    model.layers[il].ffn_norm, NULL,
-+                    LLM_NORM_RMS, cb, il);
++                    LLM_NORM_RMS, il);
 +            cb(cur, "ffn_norm", il);
 +
-+            cur = llm_build_ffn(ctx0, lctx, cur,
++            cur = build_ffn(cur,
 +                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
 +                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
 +                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
 +                    NULL,
-+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
++                    LLM_FFN_SILU, LLM_FFN_PAR, il);
 +            cb(cur, "ffn_out", il);
 +
 +            cur = ggml_add(ctx0, cur, ffn_inp);
 +            cb(cur, "ffn_out", il);
 +
-+            cur = lctx.cvec.apply_to(ctx0, cur, il);
++            cur = build_cvec(cur, il);
 +            cb(cur, "l_out", il);
 +
 +            // input for next layer
@@ -394,25 +355,64 @@ index ac85bfed..6d320ea4 100644
 +        }
 +
 +        cur = inpL;
-+        cur = llm_build_norm(ctx0, cur, hparams,
++
++        cur = build_norm(cur,
 +                model.output_norm, NULL,
-+                LLM_NORM_RMS, cb, -1);
++                LLM_NORM_RMS, -1);
++
 +        cb(cur, "result_norm", -1);
++        res->t_embd = cur;
++
 +        // lm_head
-+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
++        cur = build_lora_mm(model.output, cur);
++
 +        cb(cur, "result_output", -1);
++        res->t_logits = cur;
++
 +        ggml_build_forward_expand(gf, cur);
-         return gf;
-     }
- 
-@@ -8398,6 +8544,10 @@ static struct ggml_cgraph * llama_build_graph(
++    }
++};
++
+ struct llm_build_wavtokenizer_dec : public llm_graph_context {
+     llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+         ggml_tensor * cur;
+@@ -12993,6 +13195,10 @@ llm_graph_result_ptr llama_model::build_graph(
              {
-                 result = llm.build_chameleon();
+                 llm = std::make_unique(*this, params, gf);
              } break;
 +        case LLM_ARCH_SOLAR:
 +            {
-+                result = llm.build_solar();
++                llm = std::make_unique(*this, params, gf);
 +            } break;
          case LLM_ARCH_WAVTOKENIZER_DEC:
              {
-                 result = llm.build_wavtokenizer_dec();
+                 llm = std::make_unique(*this, params, gf);
+@@ -13139,6 +13345,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
+         case LLM_ARCH_GRANITE:
+         case LLM_ARCH_GRANITE_MOE:
+         case LLM_ARCH_CHAMELEON:
++        case LLM_ARCH_SOLAR:
+         case LLM_ARCH_BAILINGMOE:
+             return LLAMA_ROPE_TYPE_NORM;
+ 
+diff --git a/src/llama-model.h b/src/llama-model.h
+index 0f18dac1..e08d4ae4 100644
+--- a/src/llama-model.h
++++ b/src/llama-model.h
+@@ -62,6 +62,7 @@ enum llm_type {
+     LLM_TYPE_15B,
+     LLM_TYPE_16B,
+     LLM_TYPE_20B,
++    LLM_TYPE_22B,
+     LLM_TYPE_30B,
+     LLM_TYPE_32B,
+     LLM_TYPE_34B,
+@@ -305,6 +306,8 @@ struct llama_layer {
+     struct ggml_tensor * ffn_up_scale   = nullptr;
+     struct ggml_tensor * ffn_down_scale = nullptr;
+ 
++    struct ggml_tensor * bskcn_tv = nullptr;
++
+     struct llama_layer_posnet posnet;
+ 
+     struct llama_layer_convnext convnext;
diff --git a/llama/patches/0006-conditional-fattn.patch b/llama/patches/0006-conditional-fattn.patch
index 97af82e20..5be2b81d0 100644
--- a/llama/patches/0006-conditional-fattn.patch
+++ b/llama/patches/0006-conditional-fattn.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] conditional-fattn
  1 file changed, 2 insertions(+)
 
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index dfff21a2..1b0d074b 100644
+index 59a49560..b70c6a32 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -2284,9 +2284,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
+@@ -2338,9 +2338,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
          case GGML_OP_ARGSORT:
              ggml_cuda_op_argsort(ctx, dst);
              break;
diff --git a/llama/patches/0007-add-mllama-support.patch b/llama/patches/0007-add-mllama-support.patch
index efed0923b..6d7e91d7e 100644
--- a/llama/patches/0007-add-mllama-support.patch
+++ b/llama/patches/0007-add-mllama-support.patch
@@ -1,41 +1,58 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
 From: jmorganca 
-Date: Thu, 17 Oct 2024 15:18:22 -0700
+Date: Tue, 8 Apr 2025 19:27:12 -0700
 Subject: [PATCH] add mllama support
 
-mllama adds cross-attention layers to the standard llama architecture
-it also requires a way to input a new tensor: cross_attention_state
-once per generation
-
-cross-attention layers don't change and so they are cached in the
-kv cache once per run
-
-remaining is to implement the cross attention mask
+adds support for the Llama 3.2 Vision architecture
 ---
+ examples/llava/gemma3-cli.cpp |   3 +-
  examples/llava/llava.cpp      |   5 +-
+ examples/llava/mtmd.cpp       |   6 +-
  ggml/src/ggml-backend-reg.cpp |   6 +-
  include/llama.h               |   6 +
- src/llama-arch.cpp            |  44 ++++++
+ src/llama-arch.cpp            |  44 +++++
  src/llama-arch.h              |  10 ++
  src/llama-batch.cpp           |   3 +
- src/llama-context.cpp         |  28 ++--
- src/llama-context.h           |   2 +
+ src/llama-context.cpp         |  25 ++-
+ src/llama-context.h           |   1 +
  src/llama-cparams.h           |   1 +
- src/llama-hparams.cpp         |   6 +
- src/llama-hparams.h           |   5 +
- src/llama-kv-cache.cpp        |  13 +-
+ src/llama-graph.cpp           |  25 +++
+ src/llama-graph.h             |  12 ++
+ src/llama-hparams.cpp         |   4 +
+ src/llama-hparams.h           |   7 +
+ src/llama-kv-cache.cpp        |  12 +-
  src/llama-model-loader.cpp    |   2 +
- src/llama-model.cpp           |  65 ++++++++-
+ src/llama-model.cpp           | 309 +++++++++++++++++++++++++++++++++-
  src/llama-model.h             |  12 ++
  src/llama-quant.cpp           |   4 +-
- src/llama.cpp                 | 262 +++++++++++++++++++++++++++++++++-
- 17 files changed, 452 insertions(+), 22 deletions(-)
+ 20 files changed, 475 insertions(+), 22 deletions(-)
 
+diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp
+index 91a07e2a..13127c7b 100644
+--- a/examples/llava/gemma3-cli.cpp
++++ b/examples/llava/gemma3-cli.cpp
+@@ -106,7 +106,7 @@ struct decode_embd_batch {
+     std::vector seq_ids;
+     std::vector         logits;
+     llama_batch batch;
+-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
++    decode_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+         pos     .resize(n_tokens);
+         n_seq_id.resize(n_tokens);
+         seq_ids .resize(n_tokens + 1);
+@@ -118,6 +118,7 @@ struct decode_embd_batch {
+             /*n_tokens       =*/ n_tokens,
+             /*tokens         =*/ nullptr,
+             /*embd           =*/ embd,
++            /*n_embd         =*/ n_embd,
+             /*pos            =*/ pos.data(),
+             /*n_seq_id       =*/ n_seq_id.data(),
+             /*seq_id         =*/ seq_ids.data(),
 diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
-index 518aad3f..f0e484a1 100644
+index 03a22cbb..5eb40bcd 100644
 --- a/examples/llava/llava.cpp
 +++ b/examples/llava/llava.cpp
-@@ -445,7 +445,7 @@ struct llava_embd_batch {
+@@ -456,7 +456,7 @@ struct llava_embd_batch {
      std::vector seq_ids;
      std::vector         logits;
      llama_batch batch;
@@ -44,7 +61,7 @@ index 518aad3f..f0e484a1 100644
          pos     .resize(n_tokens);
          n_seq_id.resize(n_tokens);
          seq_ids .resize(n_tokens + 1);
-@@ -457,6 +457,7 @@ struct llava_embd_batch {
+@@ -468,6 +468,7 @@ struct llava_embd_batch {
              /*n_tokens       =*/ n_tokens,
              /*tokens         =*/ nullptr,
              /*embd           =*/ embd,
@@ -52,7 +69,7 @@ index 518aad3f..f0e484a1 100644
              /*pos            =*/ pos.data(),
              /*n_seq_id       =*/ n_seq_id.data(),
              /*seq_id         =*/ seq_ids.data(),
-@@ -480,7 +481,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
+@@ -491,7 +492,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
              n_eval = n_batch;
          }
          float * embd = image_embed->embed+i*n_embd;
@@ -61,11 +78,42 @@ index 518aad3f..f0e484a1 100644
          if (llama_decode(ctx_llama, llava_batch.batch)) {
              LOG_ERR("%s : failed to eval\n", __func__);
              return false;
+diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp
+index 114c274b..a0e649ad 100644
+--- a/examples/llava/mtmd.cpp
++++ b/examples/llava/mtmd.cpp
+@@ -213,7 +213,7 @@ struct decode_embd_batch {
+     std::vector seq_ids;
+     std::vector         logits;
+     llama_batch batch;
+-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
++    decode_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+         pos     .resize(n_tokens);
+         n_seq_id.resize(n_tokens);
+         seq_ids .resize(n_tokens + 1);
+@@ -225,6 +225,7 @@ struct decode_embd_batch {
+             /*n_tokens       =*/ n_tokens,
+             /*tokens         =*/ nullptr,
+             /*embd           =*/ embd,
++            /*n_embd         =*/ n_embd,
+             /*pos            =*/ pos.data(),
+             /*n_seq_id       =*/ n_seq_id.data(),
+             /*seq_id         =*/ seq_ids.data(),
+@@ -291,7 +292,8 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
+ 
+             int32_t n_tokens = chunk.tokens_image->n_tokens();
+             float * embd = mtmd_get_output_embd(ctx);
+-            decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
++            int n_embd  = llama_model_n_embd(llama_get_model(lctx));
++            decode_embd_batch batch_img(embd, n_embd, n_tokens, n_past, 0);
+             int64_t t1 = ggml_time_ms();
+             ret = llama_decode(lctx, batch_img.batch);
+             if (ret != 0) {
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 955ed505..95036ef8 100644
+index 405d8e31..82ae1b5b 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
-@@ -171,9 +171,9 @@ struct ggml_backend_registry {
+@@ -178,9 +178,9 @@ struct ggml_backend_registry {
  #ifdef GGML_USE_CANN
          register_backend(ggml_backend_cann_reg());
  #endif
@@ -79,10 +127,10 @@ index 955ed505..95036ef8 100644
          register_backend(ggml_backend_rpc_reg());
  #endif
 diff --git a/include/llama.h b/include/llama.h
-index 47919602..cc948005 100644
+index 5657fbf0..f91896e4 100644
 --- a/include/llama.h
 +++ b/include/llama.h
-@@ -249,6 +249,7 @@ extern "C" {
+@@ -255,6 +255,7 @@ extern "C" {
  
          llama_token  *  token;
          float        *  embd;
@@ -90,7 +138,7 @@ index 47919602..cc948005 100644
          llama_pos    *  pos;
          int32_t      *  n_seq_id;
          llama_seq_id ** seq_id;
-@@ -343,6 +344,7 @@ extern "C" {
+@@ -357,6 +358,7 @@ extern "C" {
          bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
          bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
          bool no_perf;     // whether to measure performance timings
@@ -98,7 +146,7 @@ index 47919602..cc948005 100644
  
          // Abort callback
          // if it returns true, execution of llama_decode() will be aborted
-@@ -443,6 +445,10 @@ extern "C" {
+@@ -458,6 +460,10 @@ extern "C" {
              struct llama_context_params   params),
              "use llama_init_from_model instead");
  
@@ -110,7 +158,7 @@ index 47919602..cc948005 100644
      LLAMA_API void llama_free(struct llama_context * ctx);
  
 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index a1e0ebcc..b6f20286 100644
+index 0b0fedcd..c1f78618 100644
 --- a/src/llama-arch.cpp
 +++ b/src/llama-arch.cpp
 @@ -6,6 +6,7 @@
@@ -118,19 +166,19 @@ index a1e0ebcc..b6f20286 100644
  static const std::map LLM_ARCH_NAMES = {
      { LLM_ARCH_LLAMA,            "llama"            },
 +    { LLM_ARCH_MLLAMA,           "mllama"           },
+     { LLM_ARCH_LLAMA4,           "llama4"           },
      { LLM_ARCH_DECI,             "deci"             },
      { LLM_ARCH_FALCON,           "falcon"           },
-     { LLM_ARCH_GROK,             "grok"             },
-@@ -127,6 +128,7 @@ static const std::map LLM_KV_NAMES = {
-     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
-     { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
-     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,  "%s.attention.block_skip_connection"  },
-+    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
+@@ -142,6 +143,7 @@ static const std::map LLM_KV_NAMES = {
+     { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
+     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
+     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,        "%s.attention.block_skip_connection"        },
++    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,       "%s.attention.cross_attention_layers"       },
  
      { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
      { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
-@@ -225,6 +227,40 @@ static const std::map> LLM_TENSOR_N
-             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+@@ -269,6 +271,40 @@ static const std::map> LLM_TENSOR_N
+             { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
          },
      },
 +    {
@@ -170,7 +218,7 @@ index a1e0ebcc..b6f20286 100644
      {
          LLM_ARCH_DECI,
          {
-@@ -1450,6 +1486,14 @@ static const std::map LLM_TENSOR_INFOS = {
+@@ -1692,6 +1728,14 @@ static const std::map LLM_TENSOR_INFOS = {
      // this tensor is loaded for T5, but never used
      {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
      {LLM_TENSOR_BSKCN_TV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -186,18 +234,18 @@ index a1e0ebcc..b6f20286 100644
      {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
      {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 diff --git a/src/llama-arch.h b/src/llama-arch.h
-index 77919578..ec742224 100644
+index 74aa3dd0..f987844d 100644
 --- a/src/llama-arch.h
 +++ b/src/llama-arch.h
-@@ -10,6 +10,7 @@
- 
+@@ -11,6 +11,7 @@
  enum llm_arch {
      LLM_ARCH_LLAMA,
+     LLM_ARCH_LLAMA4,
 +    LLM_ARCH_MLLAMA,
      LLM_ARCH_DECI,
      LLM_ARCH_FALCON,
      LLM_ARCH_BAICHUAN,
-@@ -131,6 +132,7 @@ enum llm_kv {
+@@ -146,6 +147,7 @@ enum llm_kv {
      LLM_KV_ATTENTION_SLIDING_WINDOW,
      LLM_KV_ATTENTION_SCALE,
      LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
@@ -205,7 +253,7 @@ index 77919578..ec742224 100644
  
      LLM_KV_ROPE_DIMENSION_COUNT,
      LLM_KV_ROPE_DIMENSION_SECTIONS,
-@@ -314,6 +316,14 @@ enum llm_tensor {
+@@ -343,6 +345,14 @@ enum llm_tensor {
      LLM_TENSOR_CLS,
      LLM_TENSOR_CLS_OUT,
      LLM_TENSOR_BSKCN_TV,
@@ -249,39 +297,66 @@ index 01d5ca57..8682b0e6 100644
          batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
      }
 diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index 47e79ed4..7b22fe13 100644
+index 65135172..afe6f552 100644
 --- a/src/llama-context.cpp
 +++ b/src/llama-context.cpp
-@@ -74,10 +74,19 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
-     }
+@@ -858,7 +858,7 @@ float * llama_context::get_logits_ith(int32_t i) {
+             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+         }
  
-     if (ubatch.embd) {
--        const int64_t n_embd   = hparams.n_embd;
--        const int64_t n_tokens = ubatch.n_tokens;
-+        if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) {
-+            ggml_backend_tensor_set(lctx.inp_cross_attn_state, ubatch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state));
-+            // zero out inp_embd since it's not used
-+            float * inp_embd_data = (float *)lctx.inp_embd->data;
-+            for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) {
-+                inp_embd_data[i] = 0.0f;
-+            }
-+        } else {
-+            const int64_t n_embd   = hparams.n_embd;
-+            const int64_t n_tokens = ubatch.n_tokens;
+-        return logits + j*model.vocab.n_tokens();
++        return logits + j*model.hparams.n_vocab;
+     } catch (const std::exception & err) {
+         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
+ #ifndef NDEBUG
+@@ -979,6 +979,10 @@ void llama_context::set_warmup(bool value) {
+     cparams.warmup = value;
+ }
  
--        ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
-+            ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
-+        }
-     }
++void llama_context::set_cross_attn(bool value) {
++    cparams.cross_attn = value;
++}
++
+ void llama_context::set_adapter_lora(
+             llama_adapter_lora * adapter,
+             float scale) {
+@@ -1054,7 +1058,7 @@ int llama_context::encode(llama_batch & inp_batch) {
  
-     if (ubatch.pos && lctx.inp_pos) {
-@@ -470,12 +479,11 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
- size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
-     const auto & cparams = lctx.cparams;
-     const auto & hparams = lctx.model.hparams;
--    const auto & vocab   = lctx.model.vocab;
+     const int64_t n_embd = hparams.n_embd;
  
-     const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+-    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
++    sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
+ 
+     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
+ 
+@@ -1194,10 +1198,9 @@ int llama_context::decode(llama_batch & inp_batch) {
+ 
+     const llama_batch & batch = batch_allocr.batch;
+ 
+-    const auto & vocab   = model.vocab;
+     const auto & hparams = model.hparams;
+ 
+-    const int32_t n_vocab = vocab.n_tokens();
++    const int32_t n_vocab = hparams.n_vocab;
+ 
+     const int64_t n_tokens_all = batch.n_tokens;
+     const int64_t n_embd       = hparams.n_embd;
+@@ -1245,7 +1248,7 @@ int llama_context::decode(llama_batch & inp_batch) {
+ 
+     const bool logits_all = n_outputs_all == n_tokens_all;
+ 
+-    sbatch.from_batch(batch, n_embd,
++    sbatch.from_batch(batch, batch.n_embd,
+             /* simple_split */ !kv_self->recurrent,
+             /* logits_all   */ logits_all);
+ 
+@@ -1479,12 +1482,11 @@ int llama_context::decode(llama_batch & inp_batch) {
+ 
+ int32_t llama_context::output_reserve(int32_t n_outputs) {
+     const auto & hparams = model.hparams;
+-    const auto & vocab   = model.vocab;
+ 
+     const int64_t n_outputs_max = std::max(n_outputs, n_seq_max());
  
      const auto n_batch = cparams.n_batch;
 -    const auto n_vocab = vocab.n_tokens();
@@ -289,59 +364,57 @@ index 47e79ed4..7b22fe13 100644
      const auto n_embd  = hparams.n_embd;
  
      // TODO: use a per-batch flag for logits presence instead
-@@ -542,7 +550,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
- void llama_output_reorder(struct llama_context & ctx) {
-     std::vector & out_ids = ctx.sbatch.out_ids;
+@@ -1554,7 +1556,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
+ void llama_context::output_reorder() {
+     auto & out_ids = sbatch.out_ids;
      if (!out_ids.empty()) {
--        const uint32_t n_vocab = ctx.model.vocab.n_tokens();
-+        const uint32_t n_vocab = ctx.model.hparams.n_vocab;
-         const uint32_t n_embd  = ctx.model.hparams.n_embd;
+-        const uint32_t n_vocab = model.vocab.n_tokens();
++        const uint32_t n_vocab = model.hparams.n_vocab;
+         const uint32_t n_embd  = model.hparams.n_embd;
  
-         const int32_t n_outputs = ctx.n_outputs;
-@@ -657,6 +665,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
-     ctx->cparams.causal_attn = causal_attn;
+         GGML_ASSERT((size_t) n_outputs == out_ids.size());
+@@ -2061,7 +2063,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
+     {
+         LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
+ 
+-        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
++        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.hparams.n_vocab);
+ 
+         io.write(&logits_size, sizeof(logits_size));
+ 
+@@ -2244,6 +2246,7 @@ llama_context_params llama_context_default_params() {
+         /*.offload_kqv                 =*/ true,
+         /*.flash_attn                  =*/ false,
+         /*.no_perf                     =*/ true,
++        /*.cross_attn                  =*/ false,
+         /*.abort_callback              =*/ nullptr,
+         /*.abort_callback_data         =*/ nullptr,
+     };
+@@ -2371,6 +2374,10 @@ void llama_set_warmup(llama_context * ctx, bool warmup) {
+     ctx->set_warmup(warmup);
  }
  
 +void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
-+    ctx->cparams.cross_attn = cross_attention;
++    ctx->set_cross_attn(cross_attention);
 +}
 +
- void llama_synchronize(struct llama_context * ctx) {
-     ggml_backend_sched_synchronize(ctx->sched.get());
- 
-@@ -726,7 +738,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
-             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-         }
- 
--        return ctx->logits + j*ctx->model.vocab.n_tokens();
-+        return ctx->logits + j*ctx->model.hparams.n_vocab;
-     } catch (const std::exception & err) {
-         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
- #ifndef NDEBUG
-@@ -886,7 +898,7 @@ struct llama_data_write {
-     }
- 
-     void write_logits(const struct llama_context * ctx) {
--        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens());
-+        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
- 
-         write(&logits_size, sizeof(logits_size));
- 
+ void llama_synchronize(llama_context * ctx) {
+     ctx->synchronize();
+ }
 diff --git a/src/llama-context.h b/src/llama-context.h
-index a9268b29..cf12c9d7 100644
+index 04facb54..baa03276 100644
 --- a/src/llama-context.h
 +++ b/src/llama-context.h
-@@ -107,6 +107,8 @@ struct llama_context {
-     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
-     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
-     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
-+
-+    struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
- };
+@@ -65,6 +65,7 @@ struct llama_context {
+     void set_embeddings (bool value);
+     void set_causal_attn(bool value);
+     void set_warmup(bool value);
++    void set_cross_attn(bool value);
  
- // TODO: make these methods of llama_context
+     void set_adapter_lora(
+             llama_adapter_lora * adapter,
 diff --git a/src/llama-cparams.h b/src/llama-cparams.h
-index 252012f3..9681e5a0 100644
+index 30e550f0..85ad91b9 100644
 --- a/src/llama-cparams.h
 +++ b/src/llama-cparams.h
 @@ -29,6 +29,7 @@ struct llama_cparams {
@@ -349,37 +422,115 @@ index 252012f3..9681e5a0 100644
      bool flash_attn;
      bool no_perf;
 +    bool cross_attn;
+     bool warmup;
  
      enum llama_pooling_type pooling_type;
+diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
+index cd955d63..83f3c5a8 100644
+--- a/src/llama-graph.cpp
++++ b/src/llama-graph.cpp
+@@ -546,6 +546,12 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
+     }
+ }
  
++void llm_graph_input_cross_attn_state::set_input(const llama_ubatch * ubatch) {
++    if (ubatch->embd) {
++        ggml_backend_tensor_set(cross_attn_state, ubatch->embd, 0, ggml_nbytes(cross_attn_state));
++    }
++}
++
+ //
+ // llm_graph_context
+ //
+@@ -1495,6 +1501,25 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
+     return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
+ }
+ 
++ggml_tensor * llm_graph_context::build_inp_cross_attn_state() const {
++    const int64_t n_embd = hparams.n_embd;
++
++    auto inp = std::make_unique();
++
++    ggml_tensor * cur = nullptr;
++
++    inp->cross_attn_state = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, 1601, 4);
++    ggml_set_input(inp->cross_attn_state);
++
++    cur = inp->cross_attn_state;
++
++    cb(cur, "inp_cross_attn_state", -1);
++
++    res->add_input(std::move(inp));
++
++    return cur;
++}
++
+ ggml_tensor * llm_graph_context::build_attn(
+         llm_graph_input_attn_cross * inp,
+         ggml_cgraph * gf,
+diff --git a/src/llama-graph.h b/src/llama-graph.h
+index 5b6618f9..51993998 100644
+--- a/src/llama-graph.h
++++ b/src/llama-graph.h
+@@ -86,6 +86,7 @@ public:
+ 
+     ggml_tensor * tokens = nullptr; // I32 [n_batch]
+     ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
++    ggml_tensor * cross_attn_state; // F32 [n_embd, 1601, 4]
+ };
+ 
+ class llm_graph_input_pos : public llm_graph_input_i {
+@@ -285,6 +286,16 @@ public:
+     const llama_cross * cross = nullptr;
+ };
+ 
++class llm_graph_input_cross_attn_state : public llm_graph_input_i {
++public:
++    llm_graph_input_cross_attn_state()          = default;
++    virtual ~llm_graph_input_cross_attn_state() = default;
++
++    void set_input(const llama_ubatch * ubatch) override;
++
++    ggml_tensor * cross_attn_state; // F32 [n_embd, 1601, 4]
++};
++
+ //
+ // llm_graph_result
+ //
+@@ -493,6 +504,7 @@ struct llm_graph_context {
+     ggml_tensor * build_inp_cls() const;
+     ggml_tensor * build_inp_s_copy() const;
+     ggml_tensor * build_inp_s_mask() const;
++    ggml_tensor * build_inp_cross_attn_state() const;
+ 
+     ggml_tensor * build_inp_cross_embd() const;
+     ggml_tensor * build_inp_pos_bucket_enc() const;
 diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
-index f3955de9..0b841028 100644
+index 8a667960..6a02de03 100644
 --- a/src/llama-hparams.cpp
 +++ b/src/llama-hparams.cpp
-@@ -2,6 +2,8 @@
- 
- #include "ggml.h"
- 
-+#include 
-+
- uint32_t llama_hparams::n_head(uint32_t il) const {
-     if (il < n_layer) {
-         return n_head_arr[il];
-@@ -76,4 +78,8 @@ bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const {
-     }
+@@ -85,3 +85,7 @@ bool llama_hparams::is_swa(uint32_t il) const {
  
      GGML_ABORT("fatal error");
-+}
+ }
 +
 +bool llama_hparams::cross_attention_layers(uint32_t il) const {
 +    return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
- }
-\ No newline at end of file
++}
 diff --git a/src/llama-hparams.h b/src/llama-hparams.h
-index 1bdcdfd5..05383046 100644
+index c3147cbc..4567a0e9 100644
 --- a/src/llama-hparams.h
 +++ b/src/llama-hparams.h
-@@ -41,6 +41,7 @@ struct llama_hparams {
+@@ -2,6 +2,8 @@
+ 
+ #include "llama.h"
+ 
++#include 
++
+ #include 
+ 
+ // bump if necessary
+@@ -42,6 +44,7 @@ struct llama_hparams {
      uint32_t n_expert = 0;
      uint32_t n_expert_used = 0;
      uint32_t n_rel_attn_bkts = 0;
@@ -387,7 +538,7 @@ index 1bdcdfd5..05383046 100644
  
      // for WavTokenizer
      struct llama_hparams_posnet   posnet;
-@@ -51,6 +52,7 @@ struct llama_hparams {
+@@ -52,6 +55,7 @@ struct llama_hparams {
      std::array n_ff_arr;
  
      std::array, 4> n_bskcn_arr = {};
@@ -395,21 +546,21 @@ index 1bdcdfd5..05383046 100644
  
      uint32_t n_layer_dense_lead = 0;
      uint32_t n_lora_q           = 0;
-@@ -138,6 +140,9 @@ struct llama_hparams {
- 
+@@ -154,6 +158,9 @@ struct llama_hparams {
      // Block skip connection
      bool n_bskcn(uint32_t n, uint32_t il) const;
-+
+ 
 +    // cross attention layers
 +    bool cross_attention_layers(uint32_t il) const;
++
+     bool is_swa(uint32_t il) const;
  };
  
- static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable");
 diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
-index feffdf0d..b541c5a3 100644
+index dbf5f118..9310f262 100644
 --- a/src/llama-kv-cache.cpp
 +++ b/src/llama-kv-cache.cpp
-@@ -91,8 +91,17 @@ bool llama_kv_cache_init(
+@@ -95,8 +95,16 @@ bool llama_kv_cache_unified::init(
              return false;
          }
  
@@ -425,12 +576,11 @@ index feffdf0d..b541c5a3 100644
 +            k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
 +            v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
 +        }
-+
          ggml_format_name(k, "cache_k_l%d", i);
          ggml_format_name(v, "cache_v_l%d", i);
-         cache.k_l.push_back(k);
+         k_l.push_back(k);
 diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
-index 1252aca1..45d08721 100644
+index a012aeae..2e11507d 100644
 --- a/src/llama-model-loader.cpp
 +++ b/src/llama-model-loader.cpp
 @@ -315,6 +315,8 @@ namespace GGUFMeta {
@@ -443,10 +593,10 @@ index 1252aca1..45d08721 100644
      bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) {
          const int kid = gguf_find_key(meta.get(), key.c_str());
 diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index ad1315c6..21819080 100644
+index 5fbd0055..d5ad466e 100644
 --- a/src/llama-model.cpp
 +++ b/src/llama-model.cpp
-@@ -401,6 +401,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+@@ -419,6 +419,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  
      // get general kv
      ml.get_key(LLM_KV_GENERAL_NAME, name, false);
@@ -454,7 +604,7 @@ index ad1315c6..21819080 100644
  
      // everything past this point is not vocab-related
      if (hparams.vocab_only) {
-@@ -412,6 +413,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+@@ -430,6 +431,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
      ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
      ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
      ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
@@ -462,7 +612,7 @@ index ad1315c6..21819080 100644
  
      if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
          ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
-@@ -435,9 +437,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+@@ -453,9 +455,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
      std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
      std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
      std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
@@ -474,7 +624,7 @@ index ad1315c6..21819080 100644
  
      // n_head_kv is optional, default to n_head
      hparams.n_head_kv_arr = hparams.n_head_arr;
-@@ -486,7 +490,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+@@ -508,7 +512,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  
          ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
  
@@ -483,8 +633,8 @@ index ad1315c6..21819080 100644
              if (hparams.n_rot != hparams.n_embd_head_k) {
                  throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
              }
-@@ -530,6 +534,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
-                     }
+@@ -571,6 +575,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+                     hparams.use_kq_norm = false;
                  }
              } break;
 +        case LLM_ARCH_MLLAMA:
@@ -500,7 +650,7 @@ index ad1315c6..21819080 100644
          case LLM_ARCH_DECI:
              {
                  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-@@ -1398,7 +1412,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+@@ -1548,7 +1562,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
          const int64_t n_embd_head_v = hparams.n_embd_head_v;
          const int64_t n_ff          = hparams.n_ff();
          const int64_t n_embd_gqa    = n_embd_v_gqa;
@@ -509,7 +659,7 @@ index ad1315c6..21819080 100644
          const int64_t n_token_types = vocab.n_token_types();
          const int64_t n_rot         = hparams.n_rot;
          const int64_t n_expert      = hparams.n_expert;
-@@ -1581,6 +1595,52 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+@@ -1801,6 +1815,52 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                          }
                      }
                  } break;
@@ -562,107 +712,12 @@ index ad1315c6..21819080 100644
              case LLM_ARCH_DECI:
                  {
                      tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-@@ -3925,6 +3985,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
- 
-         // use what we call a normal RoPE, operating on pairs of consecutive head values
-         case LLM_ARCH_LLAMA:
-+        case LLM_ARCH_MLLAMA:
-         case LLM_ARCH_DECI:
-         case LLM_ARCH_BAICHUAN:
-         case LLM_ARCH_STARCODER:
-diff --git a/src/llama-model.h b/src/llama-model.h
-index 1afb0024..7cf57587 100644
---- a/src/llama-model.h
-+++ b/src/llama-model.h
-@@ -9,6 +9,7 @@
- #include 
- #include 
- #include 
-+#include 
- 
- struct llama_model_loader;
- 
-@@ -63,6 +64,7 @@ enum llm_type {
-     LLM_TYPE_40B,
-     LLM_TYPE_65B,
-     LLM_TYPE_70B,
-+    LLM_TYPE_90B,
-     LLM_TYPE_236B,
-     LLM_TYPE_314B,
-     LLM_TYPE_671B,
-@@ -284,6 +286,16 @@ struct llama_layer {
- 
-     struct ggml_tensor * bskcn_tv = nullptr;
- 
-+    // cross attention
-+    struct ggml_tensor * cross_attn_k_norm = nullptr;
-+    struct ggml_tensor * cross_attn_k_proj = nullptr;
-+    struct ggml_tensor * cross_attn_o_proj = nullptr;
-+    struct ggml_tensor * cross_attn_q_norm = nullptr;
-+    struct ggml_tensor * cross_attn_q_proj = nullptr;
-+    struct ggml_tensor * cross_attn_v_proj = nullptr;
-+    struct ggml_tensor * cross_attn_attn_gate = nullptr;
-+    struct ggml_tensor * cross_attn_mlp_gate = nullptr;
-+
-     struct llama_layer_posnet posnet;
- 
-     struct llama_layer_convnext convnext;
-diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
-index fb798265..6eb1da08 100644
---- a/src/llama-quant.cpp
-+++ b/src/llama-quant.cpp
-@@ -632,7 +632,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
-         if (llama_model_has_encoder(&model)) {
-             n_attn_layer *= 3;
-         }
--        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
-+        if (qs.n_attention_wv != n_attn_layer) {
-+            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
-+        }
+@@ -4665,6 +4725,246 @@ struct llm_build_llama : public llm_graph_context {
      }
+ };
  
-     size_t total_size_org = 0;
-diff --git a/src/llama.cpp b/src/llama.cpp
-index 6d320ea4..8f7902df 100644
---- a/src/llama.cpp
-+++ b/src/llama.cpp
-@@ -154,6 +154,21 @@ static struct ggml_tensor * llm_build_inp_embd(
-     return inpL;
- }
- 
-+static struct ggml_tensor * llm_build_inp_cross_attn_state(
-+        struct ggml_context * ctx,
-+       struct llama_context & lctx,
-+        const llama_hparams & hparams,
-+         const llm_build_cb & cb) {
-+    const int64_t n_embd = hparams.n_embd;
-+
-+    struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
-+    cb(inpCAS, "inp_cross_attn_state", -1);
-+    ggml_set_input(inpCAS);
-+    lctx.inp_cross_attn_state = inpCAS;
-+
-+    return inpCAS;
-+}
-+
- static void llm_build_kv_store(
-         struct ggml_context * ctx,
-         const llama_hparams & hparams,
-@@ -1157,6 +1172,7 @@ struct llm_build_context {
-         lctx.inp_pos_bucket    = nullptr;
-         lctx.inp_embd_enc      = nullptr;
-         lctx.inp_KQ_mask_cross = nullptr;
-+        lctx.inp_cross_attn_state = nullptr;
-     }
- 
-     void free() {
-@@ -1639,6 +1655,240 @@ struct llm_build_context {
-         return gf;
-     }
- 
-+    struct ggml_cgraph * build_mllama() {
-+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-+
++struct llm_build_mllama: public llm_graph_context {
++    llm_build_mllama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
 +        // mutable variable, needed during the last layer of the computation to skip unused tokens
 +        int32_t n_tokens = this->n_tokens;
 +
@@ -670,26 +725,26 @@ index 6d320ea4..8f7902df 100644
 +        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 +        GGML_ASSERT(n_embd_head == hparams.n_rot);
 +
-+        struct ggml_tensor * cur;
-+        struct ggml_tensor * inpL;
-+        struct ggml_tensor * inpCAS;
++        ggml_tensor * cur;
++        ggml_tensor * inpL;
++        ggml_tensor * inpCAS;
 +
-+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-+        inpCAS = llm_build_inp_cross_attn_state(ctx0, lctx, hparams, cb);
++        inpL = build_inp_embd(model.tok_embd);
++        inpCAS = build_inp_cross_attn_state();
 +
-+        // inp_pos - contains the positions
-+        struct ggml_tensor * inp_pos = build_inp_pos();
++          // inp_pos - contains the positions
++        ggml_tensor * inp_pos = build_inp_pos();
 +
-+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
++        auto * inp_attn = build_attn_inp_kv_unified();
++        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 +
 +        for (int il = 0; il < n_layer; ++il) {
-+            struct ggml_tensor * inpSA = inpL;
++            ggml_tensor * inpSA = inpL;
 +
 +            // norm
-+            cur = llm_build_norm(ctx0, inpL, hparams,
++            cur = build_norm(inpL,
 +                    model.layers[il].attn_norm, NULL,
-+                    LLM_NORM_RMS, cb, il);
++                    LLM_NORM_RMS, il);
 +            cb(cur, "attn_norm", il);
 +
 +            if (hparams.cross_attention_layers(il)) {
@@ -698,7 +753,7 @@ index 6d320ea4..8f7902df 100644
 +                }
 +
 +                // cross attention layer
-+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
++                ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
 +                cb(Qcur, "Qcur", il);
 +
 +                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -707,10 +762,10 @@ index 6d320ea4..8f7902df 100644
 +                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
 +                cb(Qcur, "Qcur", il);
 +
-+                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
++                Qcur = build_norm(Qcur, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, il);
 +                cb(Qcur, "Qcur", il);
 +
-+                struct ggml_tensor * Kcur, * Vcur;
++                ggml_tensor * Kcur, * Vcur;
 +                if (ubatch.embd) {
 +                    Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
 +                    cb(Kcur, "Kcur", il);
@@ -721,10 +776,10 @@ index 6d320ea4..8f7902df 100644
 +                    Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
 +                    cb(Kcur, "Kcur", il);
 +
-+                    Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
++                    Kcur = build_norm(Kcur, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, il);
 +                    cb(Kcur, "Kcur", il);
 +
-+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
++                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self->k_l[il]));
 +
 +                    Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
 +                    cb(Vcur, "Vcur", il);
@@ -735,12 +790,12 @@ index 6d320ea4..8f7902df 100644
 +                    Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
 +                    cb(Vcur, "Vcur", il);
 +
-+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
++                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self->v_l[il]));
 +                } else {
-+                    Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
++                    Kcur = ggml_view_tensor(ctx0, kv_self->k_l[il]);
 +                    cb(Kcur, "Kcur (view)", il);
 +
-+                    Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
++                    Vcur = ggml_view_tensor(ctx0, kv_self->v_l[il]);
 +                    cb(Vcur, "Vcur (view)", il);
 +                }
 +
@@ -773,24 +828,24 @@ index 6d320ea4..8f7902df 100644
 +                cb(ffn_inp, "ffn_inp", il);
 +
 +                // feed-forward network
-+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
++                cur = build_norm(ffn_inp,
 +                        model.layers[il].ffn_norm, NULL,
-+                        LLM_NORM_RMS, cb, il);
++                        LLM_NORM_RMS, il);
 +                cb(cur, "ffn_norm", il);
 +
-+                cur = llm_build_ffn(ctx0, lctx, cur,
++                cur = build_ffn(cur,
 +                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
 +                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
 +                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
 +                        NULL,
-+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
++                        LLM_FFN_SILU, LLM_FFN_PAR, il);
 +                cb(cur, "ffn_out", il);
 +
 +                // TODO: do this inplace once?
 +                cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp);
 +                cb(cur, "ffn_out", il);
 +
-+                cur = lctx.cvec.apply_to(ctx0, cur, il);
++                cur = build_cvec(cur, il);
 +                cb(cur, "l_out", il);
 +
 +                // input for next layer
@@ -799,48 +854,53 @@ index 6d320ea4..8f7902df 100644
 +                // self attention layer
 +
 +                // rope freq factors for llama3; may return nullptr for llama2 and other models
-+                struct ggml_tensor * rope_factors = build_rope_factors(il);
++                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 +
 +                // compute Q and K and RoPE them
-+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
++                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 +                cb(Qcur, "Qcur", il);
 +                if (model.layers[il].bq) {
 +                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
 +                    cb(Qcur, "Qcur", il);
 +                }
 +
-+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
++                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
 +                cb(Kcur, "Kcur", il);
 +                if (model.layers[il].bk) {
 +                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
 +                    cb(Kcur, "Kcur", il);
 +                }
 +
-+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
++                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
 +                cb(Vcur, "Vcur", il);
 +                if (model.layers[il].bv) {
 +                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
 +                    cb(Vcur, "Vcur", il);
 +                }
 +
++                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
++                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
++                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
++
 +                Qcur = ggml_rope_ext(
-+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-+                    ext_factor, attn_factor, beta_fast, beta_slow
-+                );
-+                cb(Qcur, "Qcur", il);
++                        ctx0, Qcur, inp_pos, rope_factors,
++                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
++                        ext_factor, attn_factor, beta_fast, beta_slow
++                        );
 +
 +                Kcur = ggml_rope_ext(
-+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-+                    ext_factor, attn_factor, beta_fast, beta_slow
-+                );
++                        ctx0, Kcur, inp_pos, rope_factors,
++                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
++                        ext_factor, attn_factor, beta_fast, beta_slow
++                        );
++                
++                cb(Qcur, "Qcur", il);
 +                cb(Kcur, "Kcur", il);
++                cb(Vcur, "Vcur", il);
 +
-+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
++                cur = build_attn(inp_attn, gf,
 +                    model.layers[il].wo, model.layers[il].bo,
-+                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-+
++                    Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
 +
 +                if (il == n_layer - 1) {
 +                    // skip computing output for unused tokens
@@ -854,23 +914,23 @@ index 6d320ea4..8f7902df 100644
 +                cb(ffn_inp, "ffn_inp", il);
 +
 +                // feed-forward network
-+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
++                cur = build_norm(ffn_inp,
 +                        model.layers[il].ffn_norm, NULL,
-+                        LLM_NORM_RMS, cb, il);
++                        LLM_NORM_RMS, il);
 +                cb(cur, "ffn_norm", il);
 +
-+                cur = llm_build_ffn(ctx0, lctx, cur,
++                cur = build_ffn(cur,
 +                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
 +                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
 +                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
 +                        NULL,
-+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
++                        LLM_FFN_SILU, LLM_FFN_PAR, il);
 +                cb(cur, "ffn_out", il);
 +
 +                cur = ggml_add(ctx0, cur, ffn_inp);
 +                cb(cur, "ffn_out", il);
 +
-+                cur = lctx.cvec.apply_to(ctx0, cur, il);
++                cur = build_cvec(cur, il);
 +                cb(cur, "l_out", il);
 +
 +                // input for next layer
@@ -880,74 +940,93 @@ index 6d320ea4..8f7902df 100644
 +
 +        cur = inpL;
 +
-+        cur = llm_build_norm(ctx0, cur, hparams,
++        cur = build_norm(cur,
 +                model.output_norm, NULL,
-+                LLM_NORM_RMS, cb, -1);
++                LLM_NORM_RMS, -1);
 +        cb(cur, "result_norm", -1);
++        res->t_embd = cur;
 +
 +        // lm_head
-+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
++        cur = build_lora_mm(model.output, cur);
++
 +        cb(cur, "result_output", -1);
++        res->t_logits = cur;
 +
 +        ggml_build_forward_expand(gf, cur);
-+
-+        return gf;
 +    }
++};
 +
-     struct ggml_cgraph * build_deci() {
-         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
- 
-@@ -8344,6 +8594,10 @@ static struct ggml_cgraph * llama_build_graph(
+ struct llm_build_deci : public llm_graph_context {
+     llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+         const int64_t n_embd_head = hparams.n_embd_head_v;
+@@ -12965,6 +13265,10 @@ llm_graph_result_ptr llama_model::build_graph(
              {
-                 result = llm.build_llama();
+                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
              } break;
 +        case LLM_ARCH_MLLAMA:
 +            {
-+                result = llm.build_mllama();
++                llm = std::make_unique<llm_build_mllama>(*this, params, gf);
 +            } break;
          case LLM_ARCH_DECI:
              {
-                 result = llm.build_deci();
-@@ -8634,7 +8888,7 @@ static int llama_prepare_sbatch(
-         n_outputs = 1;
+                 llm = std::make_unique<llm_build_deci>(*this, params, gf);
+@@ -13325,6 +13629,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
+         // use what we call a normal RoPE, operating on pairs of consecutive head values
+         case LLM_ARCH_LLAMA:
+         case LLM_ARCH_LLAMA4:
++        case LLM_ARCH_MLLAMA:
+         case LLM_ARCH_DECI:
+         case LLM_ARCH_BAICHUAN:
+         case LLM_ARCH_STARCODER:
+diff --git a/src/llama-model.h b/src/llama-model.h
+index e08d4ae4..21c4617b 100644
+--- a/src/llama-model.h
++++ b/src/llama-model.h
+@@ -11,6 +11,7 @@
+ #include 
+ #include 
+ #include 
++#include 
+ 
+ struct llama_cparams;
+ struct llama_ubatch;
+@@ -70,6 +71,7 @@ enum llm_type {
+     LLM_TYPE_40B,
+     LLM_TYPE_65B,
+     LLM_TYPE_70B,
++    LLM_TYPE_90B,
+     LLM_TYPE_236B,
+     LLM_TYPE_314B,
+     LLM_TYPE_671B,
+@@ -308,6 +310,16 @@ struct llama_layer {
+ 
+     struct ggml_tensor * bskcn_tv = nullptr;
+ 
++    // cross attention
++    struct ggml_tensor * cross_attn_k_norm = nullptr;
++    struct ggml_tensor * cross_attn_k_proj = nullptr;
++    struct ggml_tensor * cross_attn_o_proj = nullptr;
++    struct ggml_tensor * cross_attn_q_norm = nullptr;
++    struct ggml_tensor * cross_attn_q_proj = nullptr;
++    struct ggml_tensor * cross_attn_v_proj = nullptr;
++    struct ggml_tensor * cross_attn_attn_gate = nullptr;
++    struct ggml_tensor * cross_attn_mlp_gate = nullptr;
++
+     struct llama_layer_posnet posnet;
+ 
+     struct llama_layer_convnext convnext;
+diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
+index 7dc54227..223e1f3f 100644
+--- a/src/llama-quant.cpp
++++ b/src/llama-quant.cpp
+@@ -639,7 +639,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
+         if (llama_model_has_encoder(&model)) {
+             n_attn_layer *= 3;
+         }
+-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
++        if (qs.n_attention_wv != n_attn_layer) {
++            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
++        }
      }
  
--    lctx.sbatch.from_batch(batch, n_embd,
-+    lctx.sbatch.from_batch(batch, batch.n_embd,
-         /* simple_split */ !lctx.kv_self.recurrent,
-         /* logits_all   */ n_outputs == n_tokens_all);
- 
-@@ -8749,7 +9003,6 @@ static int llama_decode_impl(
-     const llama_batch & batch = batch_allocr.batch;
- 
-     const auto & model   = lctx.model;
--    const auto & vocab   = model.vocab;
-     const auto & hparams = model.hparams;
-     const auto & cparams = lctx.cparams;
- 
-@@ -8760,7 +9013,7 @@ static int llama_decode_impl(
-     llama_kv_slot_restorer kv_slot_restorer(kv_self);
- 
-     const int64_t n_embd  = hparams.n_embd;
--    const int64_t n_vocab = vocab.n_tokens();
-+    const int64_t n_vocab = hparams.n_vocab;
- 
-     uint32_t n_outputs = 0;
-     uint32_t n_outputs_prev = 0;
-@@ -9025,7 +9278,7 @@ static int llama_encode_impl(
- 
-     const int64_t n_embd = hparams.n_embd;
- 
--    lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
-+    lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
- 
-     const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
- 
-@@ -9511,6 +9764,7 @@ struct llama_context_params llama_context_default_params() {
-         /*.offload_kqv                 =*/ true,
-         /*.flash_attn                  =*/ false,
-         /*.no_perf                     =*/ true,
-+        /*.cross_attn                  =*/ false,
-         /*.abort_callback              =*/ nullptr,
-         /*.abort_callback_data         =*/ nullptr,
-     };
+     size_t total_size_org = 0;
diff --git a/llama/patches/0008-add-unpad-operator.patch b/llama/patches/0008-add-unpad-operator.patch
index 31e67a6e3..ce409ce9f 100644
--- a/llama/patches/0008-add-unpad-operator.patch
+++ b/llama/patches/0008-add-unpad-operator.patch
@@ -1,24 +1,27 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang 
-Date: Thu, 17 Oct 2024 17:19:25 -0700
+From: jmorganca 
+Date: Sun, 13 Apr 2025 22:10:06 -0400
 Subject: [PATCH] add unpad operator
 
+adds the unpad operator to GGML
 ---
  ggml/include/ggml.h                  | 10 +++++
- ggml/src/ggml-cpu/ggml-cpu.c         | 58 ++++++++++++++++++++++++++++
+ ggml/src/ggml-cpu/ggml-cpu.c         |  5 +++
+ ggml/src/ggml-cpu/ops.cpp            | 55 ++++++++++++++++++++++++++++
+ ggml/src/ggml-cpu/ops.h              |  1 +
  ggml/src/ggml-cuda/ggml-cuda.cu      |  4 ++
- ggml/src/ggml-cuda/pad.cu            | 46 ++++++++++++++++++++++
+ ggml/src/ggml-cuda/pad.cu            | 46 +++++++++++++++++++++++
  ggml/src/ggml-cuda/pad.cuh           |  1 +
- ggml/src/ggml-metal/ggml-metal.m     | 33 ++++++++++++++++
- ggml/src/ggml-metal/ggml-metal.metal | 45 +++++++++++++++++++++
- ggml/src/ggml.c                      | 25 +++++++++++-
- 8 files changed, 220 insertions(+), 2 deletions(-)
+ ggml/src/ggml-metal/ggml-metal.m     | 33 +++++++++++++++++
+ ggml/src/ggml-metal/ggml-metal.metal | 45 +++++++++++++++++++++++
+ ggml/src/ggml.c                      | 25 ++++++++++++-
+ 10 files changed, 223 insertions(+), 2 deletions(-)
 
 diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
-index dd0c6a96..8d269a9c 100644
+index 8fcc16df..d19fc167 100644
 --- a/ggml/include/ggml.h
 +++ b/ggml/include/ggml.h
-@@ -487,6 +487,7 @@ extern "C" {
+@@ -488,6 +488,7 @@ extern "C" {
          GGML_OP_UPSCALE, // nearest interpolate
          GGML_OP_PAD,
          GGML_OP_PAD_REFLECT_1D,
@@ -26,7 +29,7 @@ index dd0c6a96..8d269a9c 100644
          GGML_OP_ARANGE,
          GGML_OP_TIMESTEP_EMBEDDING,
          GGML_OP_ARGSORT,
-@@ -1743,6 +1744,15 @@ extern "C" {
+@@ -1757,6 +1758,15 @@ extern "C" {
              int                   p0,
              int                   p1);
  
@@ -43,13 +46,38 @@ index dd0c6a96..8d269a9c 100644
      // timesteps: [N,]
      // return: [N, dim]
 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
-index 72325349..2f606d82 100644
+index 50400328..432942bf 100644
 --- a/ggml/src/ggml-cpu/ggml-cpu.c
 +++ b/ggml/src/ggml-cpu/ggml-cpu.c
-@@ -10844,6 +10844,59 @@ static void ggml_compute_forward_pad_reflect_1d(
+@@ -1960,6 +1960,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
+             {
+                 ggml_compute_forward_pad_reflect_1d(params, tensor);
+             } break;
++        case GGML_OP_UNPAD:
++            {
++                ggml_compute_forward_unpad(params, tensor);
++            } break;
+         case GGML_OP_ARANGE:
+             {
+                 ggml_compute_forward_arange(params, tensor);
+@@ -2282,6 +2286,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+         case GGML_OP_UPSCALE:
+         case GGML_OP_PAD:
+         case GGML_OP_PAD_REFLECT_1D:
++        case GGML_OP_UNPAD:
+         case GGML_OP_ARANGE:
+         case GGML_OP_TIMESTEP_EMBEDDING:
+         case GGML_OP_ARGSORT:
+diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
+index 6050147b..66b8da68 100644
+--- a/ggml/src/ggml-cpu/ops.cpp
++++ b/ggml/src/ggml-cpu/ops.cpp
+@@ -6531,6 +6531,61 @@ void ggml_compute_forward_pad_reflect_1d(
      }
  }
  
++// ggml_compute_forward_unpad
++
 +static void ggml_compute_forward_unpad_f32(
 +    const struct ggml_compute_params *params,
 +    struct ggml_tensor *dst) {
@@ -85,7 +113,7 @@ index 72325349..2f606d82 100644
 +    }
 +}
 +
-+static void ggml_compute_forward_unpad(
++void ggml_compute_forward_unpad(
 +    const struct ggml_compute_params * params,
 +    struct ggml_tensor * dst) {
 +
@@ -106,30 +134,23 @@ index 72325349..2f606d82 100644
  // ggml_compute_forward_arange
  
  static void ggml_compute_forward_arange_f32(
-@@ -13137,6 +13190,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
-             {
-                 ggml_compute_forward_pad_reflect_1d(params, tensor);
-             } break;
-+        case GGML_OP_UNPAD:
-+            {
-+                ggml_compute_forward_unpad(params, tensor);
-+            } break;
-         case GGML_OP_ARANGE:
-             {
-                 ggml_compute_forward_arange(params, tensor);
-@@ -13484,6 +13541,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
-         case GGML_OP_UPSCALE:
-         case GGML_OP_PAD:
-         case GGML_OP_PAD_REFLECT_1D:
-+        case GGML_OP_UNPAD:
-         case GGML_OP_ARANGE:
-         case GGML_OP_TIMESTEP_EMBEDDING:
-         case GGML_OP_ARGSORT:
+diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
+index 410a3720..3eca1cf8 100644
+--- a/ggml/src/ggml-cpu/ops.h
++++ b/ggml/src/ggml-cpu/ops.h
+@@ -71,6 +71,7 @@ void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params
+ void ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+ void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+ void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
++void ggml_compute_forward_unpad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+ void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+ void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+ void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 1b0d074b..c7a957c8 100644
+index b70c6a32..67208cba 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -2200,6 +2200,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
+@@ -2245,6 +2245,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
          case GGML_OP_PAD:
              ggml_cuda_op_pad(ctx, dst);
              break;
@@ -139,16 +160,16 @@ index 1b0d074b..c7a957c8 100644
          case GGML_OP_ARANGE:
              ggml_cuda_op_arange(ctx, dst);
              break;
-@@ -3199,6 +3202,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
-             return ggml_is_contiguous(op->src[0]);
+@@ -3223,6 +3226,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
          case GGML_OP_UPSCALE:
+             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
          case GGML_OP_PAD:
 +        case GGML_OP_UNPAD:
          case GGML_OP_ARANGE:
          case GGML_OP_TIMESTEP_EMBEDDING:
          case GGML_OP_LEAKY_RELU:
 diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
-index aba539e8..b4b87409 100644
+index 77432b04..7d45a7e1 100644
 --- a/ggml/src/ggml-cuda/pad.cu
 +++ b/ggml/src/ggml-cuda/pad.cu
 @@ -47,3 +47,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -212,10 +233,10 @@ index 8fd386b0..e2ededc3 100644
  void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 +void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index fd9a4e77..e4c093f9 100644
+index 310afe8a..b121ab9e 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
-@@ -331,6 +331,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
+@@ -341,6 +341,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
      GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
      GGML_METAL_KERNEL_TYPE_PAD_F32,
      GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
@@ -223,23 +244,23 @@ index fd9a4e77..e4c093f9 100644
      GGML_METAL_KERNEL_TYPE_ARANGE_F32,
      GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
      GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
-@@ -946,6 +947,7 @@ @implementation GGMLMetalClass
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,            pad_reflect_1d_f32,             true);
-+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32,                     unpad_f32,                      true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
-@@ -1254,6 +1256,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
-         case GGML_OP_UPSCALE:
+@@ -998,6 +999,7 @@ @implementation GGMLMetalClass
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                     upscale_f32,                     true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                         pad_f32,                         true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,              pad_reflect_1d_f32,              true);
++        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32,                       unpad_f32,                       true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,          timestep_embedding_f32,          true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                      arange_f32,                      true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,             argsort_f32_i32_asc,             true);
+@@ -1339,6 +1341,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
+         case GGML_OP_POOL_2D:
          case GGML_OP_PAD:
          case GGML_OP_PAD_REFLECT_1D:
 +        case GGML_OP_UNPAD:
-         case GGML_OP_ARANGE:
          case GGML_OP_TIMESTEP_EMBEDDING:
          case GGML_OP_ARGSORT:
-@@ -3469,6 +3472,36 @@ static void ggml_metal_encode_node(
+         case GGML_OP_LEAKY_RELU:
+@@ -3669,6 +3672,36 @@ static void ggml_metal_encode_node(
  
                  const int nth = MIN(1024, ne0);
  
@@ -277,10 +298,10 @@ index fd9a4e77..e4c093f9 100644
              } break;
          case GGML_OP_ARANGE:
 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index d092a169..f38909d0 100644
+index b08666e2..e3185e5b 100644
 --- a/ggml/src/ggml-metal/ggml-metal.metal
 +++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -2953,6 +2953,51 @@ kernel void kernel_pad_reflect_1d_f32(
+@@ -2968,6 +2968,51 @@ kernel void kernel_pad_reflect_1d_f32(
      }
  }
  
@@ -331,12 +352,12 @@ index d092a169..f38909d0 100644
 +
  kernel void kernel_arange_f32(
      device        char * dst,
-     constant   int64_t & ne0,
+     constant   ggml_metal_kargs_arange & args,
 diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
-index 7fc06724..635aa299 100644
+index 950772c7..2276b631 100644
 --- a/ggml/src/ggml.c
 +++ b/ggml/src/ggml.c
-@@ -962,6 +962,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+@@ -963,6 +963,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "UPSCALE",
      "PAD",
      "PAD_REFLECT_1D",
@@ -344,16 +365,16 @@ index 7fc06724..635aa299 100644
      "ARANGE",
      "TIMESTEP_EMBEDDING",
      "ARGSORT",
-@@ -996,7 +997,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+@@ -993,7 +994,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "OPT_STEP_ADAMW",
  };
  
--static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
-+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
+-static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
++static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
  
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "none",
-@@ -1059,6 +1060,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+@@ -1057,6 +1058,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "upscale(x)",
      "pad(x)",
      "pad_reflect_1d(x)",
@@ -361,16 +382,16 @@ index 7fc06724..635aa299 100644
      "arange(start, stop, step)",
      "timestep_embedding(timesteps, dim, max_period)",
      "argsort(x)",
-@@ -1093,7 +1095,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+@@ -1087,7 +1089,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "adamw(x)",
  };
  
--static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
-+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
+-static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
++static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
  
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
  
-@@ -4225,6 +4227,25 @@ struct ggml_tensor * ggml_pad_reflect_1d(
+@@ -4262,6 +4264,25 @@ struct ggml_tensor * ggml_pad_reflect_1d(
      return result;
  }
  
diff --git a/llama/patches/0009-fix-deepseek-deseret-regex.patch b/llama/patches/0009-fix-deepseek-deseret-regex.patch
index 715c5206f..7c7336492 100644
--- a/llama/patches/0009-fix-deepseek-deseret-regex.patch
+++ b/llama/patches/0009-fix-deepseek-deseret-regex.patch
@@ -1,20 +1,21 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Daniel Hiltgen 
-Date: Fri, 25 Oct 2024 16:25:18 -0700
+From: jmorganca 
+Date: Tue, 8 Apr 2025 19:43:06 -0700
 Subject: [PATCH] fix deepseek deseret regex
 
-On windows compiled with gcc the c++ regex library failed to handle
-the characters
+on some systems, such as windows compiled with gcc,
+deepseek's regex would throw an error due to the deseret
+characters in the matching regex
 ---
  src/llama-vocab.cpp |  2 +-
- src/unicode.cpp     | 22 ++++++++++++++++++++++
- 2 files changed, 23 insertions(+), 1 deletion(-)
+ src/unicode.cpp     | 21 +++++++++++++++++++++
+ 2 files changed, 22 insertions(+), 1 deletion(-)
 
 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index a4eee9b8..1ca827eb 100644
+index 0125ee53..d74919d2 100644
 --- a/src/llama-vocab.cpp
 +++ b/src/llama-vocab.cpp
-@@ -295,7 +295,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
+@@ -296,7 +296,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
              case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
                  regex_exprs = {
                      "[\r\n]",
@@ -24,7 +25,7 @@ index a4eee9b8..1ca827eb 100644
                      "\\s+$",
                      "[一-龥ࠀ-一가-퟿]+",
 diff --git a/src/unicode.cpp b/src/unicode.cpp
-index e63bb4ab..9dd53b9a 100644
+index e63bb4ab..73cb2b1a 100644
 --- a/src/unicode.cpp
 +++ b/src/unicode.cpp
 @@ -2,6 +2,11 @@
@@ -39,7 +40,7 @@ index e63bb4ab..9dd53b9a 100644
  #include "unicode.h"
  #include "unicode-data.h"
  
-@@ -200,6 +205,22 @@ static std::unordered_map unicode_utf8_to_byte_map() {
+@@ -200,6 +205,21 @@ static std::unordered_map unicode_utf8_to_byte_map() {
  }
  
  static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
@@ -58,11 +59,10 @@ index e63bb4ab..9dd53b9a 100644
 +    free(wbuf);
 +    return ret;
 +#else
-+
  #if defined(__clang__)
      // disable C++17 deprecation warning for std::codecvt_utf8
  #    pragma clang diagnostic push
-@@ -213,6 +234,7 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
+@@ -213,6 +233,7 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
  #endif
  
      return conv.from_bytes(s);
diff --git a/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch b/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch
index 1e930fb29..5504b1d31 100644
--- a/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch
+++ b/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch
@@ -1,14 +1,14 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: ParthSareen 
-Date: Wed, 11 Dec 2024 15:37:32 -0800
-Subject: [PATCH] Maintain ordering for rules for grammar
+From: jmorganca 
+Date: Tue, 8 Apr 2025 19:43:40 -0700
+Subject: [PATCH] maintain ordering for rules for grammar
 
 ---
  common/json-schema-to-grammar.cpp | 2 +-
  1 file changed, 1 insertion(+), 1 deletion(-)
 
 diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
-index 3ebcc3d9..30c28808 100644
+index 90679822..56043678 100644
 --- a/common/json-schema-to-grammar.cpp
 +++ b/common/json-schema-to-grammar.cpp
 @@ -346,7 +346,7 @@ private:
diff --git a/llama/patches/0011-ensure-KV-cache-is-fully-defragmented.patch b/llama/patches/0011-ensure-KV-cache-is-fully-defragmented.patch
new file mode 100644
index 000000000..f1fe118e6
--- /dev/null
+++ b/llama/patches/0011-ensure-KV-cache-is-fully-defragmented.patch
@@ -0,0 +1,361 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca 
+Date: Tue, 15 Apr 2025 14:27:40 -0400
+Subject: [PATCH] ensure KV cache is fully defragmented
+
+Sometimes the KV cache requires defragmentation even without
+triggering the threshold heuristic. In this case, decoding
+will not be able to find a KV cache slot. This is particularly
+difficult for the caller to handle if it happens in between
+ubatches. To avoid this, we should immediately trigger a defrag.
+
+In addition, a heavily fragmented cache can require more than
+max_moves to defragment. Currently, we stop when we hit the limit
+but this can leave a cache that still does not have adequate space
+even after defragmentation is triggered. Instead, we should do
+multiple batches of processing until everything is complete.
+---
+ src/llama-context.cpp  | 105 +++++++++++++----------------------------
+ src/llama-context.h    |   4 +-
+ src/llama-kv-cache.cpp |  39 +++------------
+ src/llama-kv-cache.h   |   9 +++-
+ 4 files changed, 51 insertions(+), 106 deletions(-)
+
+diff --git a/src/llama-context.cpp b/src/llama-context.cpp
+index afe6f552..d6e7b3af 100644
+--- a/src/llama-context.cpp
++++ b/src/llama-context.cpp
+@@ -590,13 +590,12 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
+ 
+ llm_graph_result_ptr llama_context::build_kv_self_defrag(
+         ggml_context * ctx0,
+-        ggml_cgraph * gf) const {
++        ggml_cgraph * gf,
++        const std::vector<llama_kv_defrag_move> & moves) const {
+     auto res = std::make_unique<llm_graph_result>();
+ 
+     const auto & hparams = model.hparams;
+ 
+-    const auto & ids = kv_self->defrag_info.ids;
+-
+ #if 0
+     // CPU defrag
+     //
+@@ -668,32 +667,20 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
+         ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
+     }
+ #else
+-    for (uint32_t i = 0; i < ids.size(); ++i) {
+-        const uint32_t id = ids[i];
+-
+-        if (i == id || id == ids.size()) {
+-            continue;
+-        }
+-
+-        uint32_t nm = 1;
+-
+-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
+-            nm++;
+-        }
+-
++    for (const auto & move : moves) {
+         for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
+             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+ 
+             ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
+-                    n_embd_k_gqa, nm,
++                    n_embd_k_gqa, move.len,
+                     ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
++                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*move.src));
+ 
+             ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
+-                    n_embd_k_gqa, nm,
++                    n_embd_k_gqa, move.len,
+                     ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
++                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*move.dst));
+ 
+             ggml_tensor * view_v_src;
+             ggml_tensor * view_v_dst;
+@@ -701,34 +688,30 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
+             if (cparams.flash_attn) {
+                 // NOTE: the V cache is not transposed when using flash attention
+                 view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+-                        n_embd_v_gqa, nm,
++                        n_embd_v_gqa, move.len,
+                         ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+-                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
++                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*move.src));
+ 
+                 view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+-                        n_embd_v_gqa, nm,
++                        n_embd_v_gqa, move.len,
+                         ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+-                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
++                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*move.dst));
+             } else {
+                 view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+-                        nm, n_embd_v_gqa,
++                        move.len, n_embd_v_gqa,
+                         ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
+-                        ggml_row_size(kv_self->v_l[il]->type, i));
++                        ggml_row_size(kv_self->v_l[il]->type, move.src));
+ 
+                 view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+-                        nm, n_embd_v_gqa,
++                        move.len, n_embd_v_gqa,
+                         ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
+-                        ggml_row_size(kv_self->v_l[il]->type, id));
++                        ggml_row_size(kv_self->v_l[il]->type, move.dst));
+             }
+ 
+             ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
+             ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
+         }
+-
+-        i += nm - 1;
+     }
+-
+-    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+ #endif
+ 
+     return res;
+@@ -737,8 +720,6 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
+ void llama_context::kv_self_update() {
+     auto & kv = kv_self;
+ 
+-    bool need_reserve = false;
+-
+     if (kv->has_shift) {
+         if (!kv->get_can_shift()) {
+             GGML_ABORT("The current context does not support K-shift");
+@@ -759,8 +740,6 @@ void llama_context::kv_self_update() {
+             res->set_inputs(nullptr);
+ 
+             graph_compute(gf, false);
+-
+-            need_reserve = true;
+         }
+ 
+         {
+@@ -775,49 +754,28 @@ void llama_context::kv_self_update() {
+     // defragment the KV cache if needed
+     if (kv->do_defrag) {
+         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
++        const uint32_t n_max_nodes = graph_max_nodes();
++        const uint32_t max_moves = (n_max_nodes - 2*model.hparams.n_layer)/(6*model.hparams.n_layer);
++        if (!kv->defrag_prepare(n_max_nodes)) {
++            LLAMA_LOG_ERROR("%s: failed to prepare defragmentation\n", __func__);
++            return;
++        }
+ 
+-        if (kv->defrag_prepare(graph_max_nodes())) {
+-            ggml_backend_sched_reset(sched.get());
++        for (std::size_t i = 0; i < kv_self->defrag_info.moves.size(); i += max_moves) {
++            std::vector<llama_kv_defrag_move> chunk;
++            auto end = std::min(i + max_moves, kv_self->defrag_info.moves.size());
++            chunk.assign(kv_self->defrag_info.moves.begin() + i, kv_self->defrag_info.moves.begin() + end);
+ 
++            ggml_backend_sched_reset(sched.get());
+             auto * gf = graph_init();
+-
+-            auto res = build_kv_self_defrag(ctx_compute.get(), gf);
+-
++            auto res = build_kv_self_defrag(ctx_compute.get(), gf, chunk);
+             ggml_backend_sched_alloc_graph(sched.get(), gf);
+-
+             res->set_inputs(nullptr);
+-
+             graph_compute(gf, false);
+-
+-            need_reserve = true;
+         }
+ 
+         kv->do_defrag = false;
+     }
+-
+-    // reserve a worst case graph if needed
+-    if (need_reserve) {
+-        LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+-
+-        // build worst-case graph
+-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+-        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+-
+-        // simulate full KV cache
+-        kv_self->n = kv_self->size;
+-
+-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+-
+-        auto * gf = graph_init();
+-        graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+-
+-        // initialize scheduler with the worst-case graph
+-        ggml_backend_sched_reset(sched.get());
+-        if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+-        }
+-    }
+ }
+ 
+ enum llama_pooling_type llama_context::pooling_type() const {
+@@ -1301,9 +1259,12 @@ int llama_context::decode(llama_batch & inp_batch) {
+         // find KV slot
+         {
+             if (!kv_self->find_slot(ubatch)) {
+-                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+-
+-                return 1;
++                kv_self->defrag();
++                kv_self_update();
++                if (!kv_self->find_slot(ubatch)) {
++                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
++                    return 1;
++                }
+             }
+ 
+             if (!kv_self->recurrent) {
+diff --git a/src/llama-context.h b/src/llama-context.h
+index baa03276..a59ff8fd 100644
+--- a/src/llama-context.h
++++ b/src/llama-context.h
+@@ -5,6 +5,7 @@
+ #include "llama-cparams.h"
+ #include "llama-graph.h"
+ #include "llama-adapter.h"
++#include "llama-kv-cache.h"
+ 
+ #include "ggml-cpp.h"
+ 
+@@ -180,7 +181,8 @@ private:
+ 
+     llm_graph_result_ptr build_kv_self_defrag(
+             ggml_context * ctx0,
+-            ggml_cgraph * gf) const;
++            ggml_cgraph * gf,
++            const std::vector<llama_kv_defrag_move> & moves) const;
+ 
+     // TODO: read/write lora adapters and cvec
+     size_t state_write_data(llama_io_write_i & io);
+diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
+index 9310f262..5c941e7c 100644
+--- a/src/llama-kv-cache.cpp
++++ b/src/llama-kv-cache.cpp
+@@ -781,17 +781,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+ 
+     assert(n_used <= n_kv);
+ 
+-    //const int64_t t_start = ggml_time_us();
+-
+-    // number of cells moved
+-    uint32_t n_moves = 0;
+-
+-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
+-    //   - source view, destination view, copy operation
+-    //   - x2 for keys and values
+-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
+-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
+-    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
++    defrag_info.moves.clear();
+ 
+     // determine which KV cells to move where
+     //
+@@ -799,10 +789,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+     //
+     //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
+     //
+-    auto & ids = defrag_info.ids;
+-
+-    ids.clear();
+-    ids.resize(n_kv, n_kv);
++    std::vector<uint32_t> ids(n_kv, n_kv);
+ 
+     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
+         const auto & cell0 = cells[i0];
+@@ -851,19 +838,11 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+         // are we moving a continuous block of memory?
+         bool cont = false;
+ 
+-        // should we stop searching for the next move?
+-        bool stop = false;
+-
+         // go back and move the nf cells to the hole
+         for (; i1 < n_kv; ++i1) {
+             auto & cell1 = cells[i1];
+ 
+             if (cell1.is_empty() || ids[i1] != n_kv) {
+-                if (n_moves == max_moves) {
+-                    stop = true;
+-                    break;
+-                }
+-
+                 cont = false;
+                 continue;
+             }
+@@ -879,8 +858,10 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+             head = n_used;
+ 
+             if (!cont) {
+-                n_moves++;
++                defrag_info.moves.push_back({i1, i0 + nf, 1});
+                 cont = true;
++            } else {
++                defrag_info.moves.back().len++;
+             }
+ 
+             nf++;
+@@ -890,22 +871,16 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+             }
+         }
+ 
+-        if (stop || n_moves == max_moves) {
+-            break;
+-        }
+-
+         //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
+ 
+         i0 += nh - 1;
+     }
+ 
+-    if (n_moves == 0) {
++    if (defrag_info.moves.size() == 0) {
+         return false;
+     }
+ 
+-    LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
+-
+-    LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer);
++    // LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
+ 
+     return true;
+ }
+diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
+index 56c74035..25cbcb56 100644
+--- a/src/llama-kv-cache.h
++++ b/src/llama-kv-cache.h
+@@ -43,6 +43,13 @@ private:
+     llama_kv_cache * kv;
+ };
+ 
++// block of KV slots to move when defragging
++struct llama_kv_defrag_move {
++    uint32_t src;
++    uint32_t dst;
++    uint32_t len;
++};
++
+ struct llama_kv_cell {
+     llama_pos pos   = -1;
+     llama_pos delta =  0;
+@@ -131,7 +138,7 @@ public:
+     // defrag
+ 
+     struct {
+-        std::vector<uint32_t> ids;
++        std::vector<llama_kv_defrag_move> moves;
+     } defrag_info;
+ 
+     // return true if cells have been moved
diff --git a/llama/patches/0011-llama-Ensure-KV-cache-is-fully-defragmented.patch b/llama/patches/0011-llama-Ensure-KV-cache-is-fully-defragmented.patch
deleted file mode 100644
index ff0575390..000000000
--- a/llama/patches/0011-llama-Ensure-KV-cache-is-fully-defragmented.patch
+++ /dev/null
@@ -1,242 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Jesse Gross 
-Date: Fri, 13 Dec 2024 16:11:59 -0800
-Subject: [PATCH] llama: Ensure KV cache is fully defragmented.
-
-Sometimes the KV cache requires defragmentation even without
-triggering the threshold heuristic. In this case, decoding
-will not being able to find a KV cache slot. This is particularly
-difficult for the caller to handle if it happens in between
-ubatches. To avoid this, we should immediately trigger a defrag.
-
-In addition, a heavily fragmented cache can require more than
-max_moves to defragment. Currently, we stop when we hit the limit
-but this can leave a cache that still does not have adequate space
-even after defragmentation is triggered. Instead, we should do
-multiple batches of processing until everything is complete.
----
- src/llama.cpp | 99 ++++++++++++++++++++++++---------------------------
- 1 file changed, 46 insertions(+), 53 deletions(-)
-
-diff --git a/src/llama.cpp b/src/llama.cpp
-index 8f7902df..01854fce 100644
---- a/src/llama.cpp
-+++ b/src/llama.cpp
-@@ -1054,6 +1054,13 @@ static struct ggml_tensor * llm_build_rwkv6_channel_mix(
-     return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
- }
- 
-+// block of KV slots to move when defragging
-+struct llama_kv_defrag_move {
-+    uint32_t src;
-+    uint32_t dst;
-+    uint32_t len;
-+};
-+
- struct llm_build_context {
-     const llama_model    & model;
-           llama_context  & lctx;
-@@ -1230,35 +1237,23 @@ struct llm_build_context {
-         return gf;
-     }
- 
--    struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
-+    struct ggml_cgraph * build_defrag(const std::vector<llama_kv_defrag_move> & moves) {
-         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
- 
--        for (uint32_t i = 0; i < ids.size(); ++i) {
--            const uint32_t id = ids[i];
--
--            if (i == id || id == ids.size()) {
--                continue;
--            }
--
--            uint32_t nm = 1;
--
--            while (i + nm < ids.size() && ids[i + nm] == id + nm) {
--                nm++;
--            }
--
-+        for (const auto & move : moves) {
-             for (int il = 0; il < n_layer; ++il) {
-                 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-                 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
- 
-                 ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
--                        n_embd_k_gqa, nm,
-+                        n_embd_k_gqa, move.len,
-                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
--                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
-+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
- 
-                 ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
--                        n_embd_k_gqa, nm,
-+                        n_embd_k_gqa, move.len,
-                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
--                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
-+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
- 
-                 ggml_tensor * view_v_src;
-                 ggml_tensor * view_v_dst;
-@@ -1266,31 +1261,29 @@ struct llm_build_context {
-                 if (flash_attn) {
-                     // NOTE: the V cache is not transposed when using flash attention
-                     view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
--                            n_embd_v_gqa, nm,
-+                            n_embd_v_gqa, move.len,
-                             ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
--                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
-+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
- 
-                     view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
--                            n_embd_v_gqa, nm,
-+                            n_embd_v_gqa, move.len,
-                             ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
--                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
-+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
-                 } else {
-                     view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
--                            nm, n_embd_v_gqa,
-+                            move.len, n_embd_v_gqa,
-                             ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
--                            ggml_row_size(kv_self.v_l[il]->type, i));
-+                            ggml_row_size(kv_self.v_l[il]->type, move.src));
- 
-                     view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
--                            nm, n_embd_v_gqa,
-+                            move.len, n_embd_v_gqa,
-                             ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
--                            ggml_row_size(kv_self.v_l[il]->type, id));
-+                            ggml_row_size(kv_self.v_l[il]->type, move.dst));
-                 }
- 
-                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
-                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
-             }
--
--            i += nm - 1;
-         }
- 
-         //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-@@ -8508,7 +8501,7 @@ struct llm_build_context {
-     }
- };
- 
--static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
-+static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<llama_kv_defrag_move> & moves) {
-     llama_ubatch dummy = {};
-     dummy.equal_seqs = true;
- 
-@@ -8518,7 +8511,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
- 
-     llm.init();
- 
--    struct ggml_cgraph * result = llm.build_defrag(ids);
-+    struct ggml_cgraph * result = llm.build_defrag(moves);
- 
-     llm.free();
- 
-@@ -8956,7 +8949,12 @@ static int llama_prepare_ubatch(
-             kv_self.head = 0;
-         }
- 
--        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-+        auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-+        if (!slot) {
-+            llama_kv_cache_defrag(kv_self);
-+            llama_kv_cache_update(&lctx);
-+            slot = llama_kv_cache_find_slot(kv_self, ubatch);
-+        }
-         if (!slot) {
-             return 1;
-         }
-@@ -9431,8 +9429,8 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
- 
-     //const int64_t t_start = ggml_time_us();
- 
--    // number of cells moved
--    uint32_t n_moves = 0;
-+    // groups of cells moved
-+    std::vector<llama_kv_defrag_move> moves;
- 
-     // each move requires 6*n_layer tensors (see build_defrag)
-     //   - source view, destination view, copy operation
-@@ -9496,19 +9494,11 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
-         // are we moving a continuous block of memory?
-         bool cont = false;
- 
--        // should we stop searching for the next move?
--        bool stop = false;
--
-         // go back and move the nf cells to the hole
-         for (; i1 < n_kv; ++i1) {
-             auto & cell1 = kv_self.cells[i1];
- 
-             if (cell1.is_empty() || ids[i1] != n_kv) {
--                if (n_moves == max_moves) {
--                    stop = true;
--                    break;
--                }
--
-                 cont = false;
-                 continue;
-             }
-@@ -9524,8 +9514,10 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
-             kv_self.head = n_used;
- 
-             if (!cont) {
--                n_moves++;
-+                moves.push_back({i1, i0 + nf, 1});
-                 cont = true;
-+            } else {
-+                moves.back().len++;
-             }
- 
-             nf++;
-@@ -9535,22 +9527,16 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
-             }
-         }
- 
--        if (stop || n_moves == max_moves) {
--            break;
--        }
--
-         //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
- 
-         i0 += nh - 1;
-     }
- 
--    if (n_moves == 0) {
-+    if (moves.size() == 0) {
-         return;
-     }
- 
--    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
--
--    //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
-+    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n",  moves.size());
- 
- #if 0
-     // CPU defrag
-@@ -9625,11 +9611,18 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
- #else
-     // ggml_graph defrag
- 
--    ggml_backend_sched_reset(lctx.sched.get());
-+    for (std::size_t i = 0; i < moves.size(); i += max_moves) {
-+        std::vector<llama_kv_defrag_move> chunk;
-+        auto end = std::min(i + max_moves, moves.size());
-+        chunk.assign(moves.begin() + i, moves.begin() + end);
- 
--    ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
-+        ggml_backend_sched_reset(lctx.sched.get());
-+
-+        //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
-+        ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
- 
--    llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
-+        llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
-+    }
- #endif
- 
-     //const int64_t t_end = ggml_time_us();
diff --git a/llama/patches/0013-sort-devices-by-score.patch b/llama/patches/0012-sort-devices-by-score.patch
similarity index 79%
rename from llama/patches/0013-sort-devices-by-score.patch
rename to llama/patches/0012-sort-devices-by-score.patch
index 7640a8db0..8c3908cf6 100644
--- a/llama/patches/0013-sort-devices-by-score.patch
+++ b/llama/patches/0012-sort-devices-by-score.patch
@@ -1,17 +1,20 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang 
-Date: Tue, 14 Jan 2025 12:01:24 -0800
+From: jmorganca 
+Date: Tue, 8 Apr 2025 20:31:38 -0700
 Subject: [PATCH] sort devices by score
 
+in the ggml backend loading code, devices
+are now sorted by score, ensuring the device
+with the fastest acceleration is loaded
 ---
  ggml/src/ggml-backend-reg.cpp | 21 +++++++++++++--------
  1 file changed, 13 insertions(+), 8 deletions(-)
 
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 95036ef8..98d5e14d 100644
+index 82ae1b5b..1487f322 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
-@@ -150,7 +150,7 @@ struct ggml_backend_reg_entry {
+@@ -157,7 +157,7 @@ struct ggml_backend_reg_entry {
  
  struct ggml_backend_registry {
      std::vector backends;
@@ -20,7 +23,7 @@ index 95036ef8..98d5e14d 100644
  
      ggml_backend_registry() {
  #ifdef GGML_USE_CUDA
-@@ -195,7 +195,7 @@ struct ggml_backend_registry {
+@@ -202,7 +202,7 @@ struct ggml_backend_registry {
          }
      }
  
@@ -29,7 +32,7 @@ index 95036ef8..98d5e14d 100644
          if (!reg) {
              return;
          }
-@@ -206,15 +206,20 @@ struct ggml_backend_registry {
+@@ -213,15 +213,20 @@ struct ggml_backend_registry {
  #endif
          backends.push_back({ reg, std::move(handle) });
          for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
@@ -52,17 +55,17 @@ index 95036ef8..98d5e14d 100644
 +        );
      }
  
-     ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
-@@ -257,7 +262,7 @@ struct ggml_backend_registry {
+     ggml_backend_reg_t load_backend(const fs::path & path, bool silent) {
+@@ -265,7 +270,7 @@ struct ggml_backend_registry {
  
-         GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
+         GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str());
  
 -        register_backend(reg, std::move(handle));
 +        register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
  
          return reg;
      }
-@@ -280,7 +285,7 @@ struct ggml_backend_registry {
+@@ -288,7 +293,7 @@ struct ggml_backend_registry {
          // remove devices
          devices.erase(
              std::remove_if(devices.begin(), devices.end(),
@@ -71,7 +74,7 @@ index 95036ef8..98d5e14d 100644
              devices.end());
  
          // remove backend
-@@ -338,7 +343,7 @@ size_t ggml_backend_dev_count() {
+@@ -346,7 +351,7 @@ size_t ggml_backend_dev_count() {
  
  ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
      GGML_ASSERT(index < ggml_backend_dev_count());
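
The renamed patch keeps an integer score next to each registered device and re-sorts the list so the highest-scoring (fastest) backend is enumerated first. A rough standalone illustration of the same idea; device_t, scored_device, and register_device are placeholders, not the ggml types:

    #include <algorithm>
    #include <string>
    #include <utility>
    #include <vector>

    using device_t      = std::string;                // placeholder for ggml_backend_dev_t
    using scored_device = std::pair<int, device_t>;   // score paired with the device

    // add a device and keep the list ordered by descending score,
    // so the fastest acceleration backend is enumerated first
    void register_device(std::vector<scored_device> & devices, int score, device_t dev) {
        devices.emplace_back(score, std::move(dev));
        std::stable_sort(devices.begin(), devices.end(),
                         [](const scored_device & a, const scored_device & b) {
                             return a.first > b.first;
                         });
    }
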
diff --git a/llama/patches/0012-use-dynamic-backend-loading-for-clip.patch b/llama/patches/0012-use-dynamic-backend-loading-for-clip.patch
deleted file mode 100644
index e28571695..000000000
--- a/llama/patches/0012-use-dynamic-backend-loading-for-clip.patch
+++ /dev/null
@@ -1,102 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Sat, 4 Jan 2025 22:52:48 -0800
-Subject: [PATCH] use dynamic backend loading for clip
-
----
- examples/llava/clip.cpp | 74 +++++++++++++++--------------------------
- 1 file changed, 27 insertions(+), 47 deletions(-)
-
-diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
-index 205af1eb..560021c7 100644
---- a/examples/llava/clip.cpp
-+++ b/examples/llava/clip.cpp
-@@ -9,25 +9,25 @@
- #include "ggml-backend.h"
- #include "gguf.h"
- 
--//#ifdef GGML_USE_CUDA
--//#include "ggml-cuda.h"
--//#endif
--//
--//#ifdef GGML_USE_SYCL
--//#include "ggml-sycl.h"
--//#endif
--//
--//#ifdef GGML_USE_METAL
--//#include "ggml-metal.h"
--//#endif
--//
--//#ifdef GGML_USE_CANN
--//#include "ggml-cann.h"
--//#endif
--//
--//#ifdef GGML_USE_VULKAN
--//#include "ggml-vulkan.h"
--//#endif
-+#ifdef GGML_USE_CUDA
-+#include "ggml-cuda.h"
-+#endif
-+
-+#ifdef GGML_USE_SYCL
-+#include "ggml-sycl.h"
-+#endif
-+
-+#ifdef GGML_USE_METAL
-+#include "ggml-metal.h"
-+#endif
-+
-+#ifdef GGML_USE_CANN
-+#include "ggml-cann.h"
-+#endif
-+
-+#ifdef GGML_USE_VULKAN
-+#include "ggml-vulkan.h"
-+#endif
- 
- #define STB_IMAGE_IMPLEMENTATION
- #include "stb_image.h"
-@@ -1309,35 +1309,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
-         }
-     }
- 
--//#ifdef GGML_USE_CUDA
--//    new_clip->backend = ggml_backend_cuda_init(0);
--//    LOG_INF("%s: CLIP using CUDA backend\n", __func__);
--//#endif
--//
--//#ifdef GGML_USE_METAL
--//    new_clip->backend = ggml_backend_metal_init();
--//    LOG_INF("%s: CLIP using Metal backend\n", __func__);
--//#endif
--//
--//#ifdef GGML_USE_CANN
--//    new_clip->backend = ggml_backend_cann_init(0);
--//    LOG_INF("%s: CLIP using CANN backend\n", __func__);
--//#endif
--//
--//#ifdef GGML_USE_VULKAN
--//    new_clip->backend = ggml_backend_vk_init(0);
--//    LOG_INF("%s: CLIP using Vulkan backend\n", __func__);
--//#endif
--//
--//#ifdef GGML_USE_SYCL
--//    new_clip->backend = ggml_backend_sycl_init(0);
--//    LOG_INF("%s: CLIP using SYCL backend\n", __func__);
--//#endif
--
--    if (!new_clip->backend) {
--        new_clip->backend = ggml_backend_cpu_init();
--        LOG_INF("%s: CLIP using CPU backend\n", __func__);
-+    ggml_backend_t backend = ggml_backend_init_best();
-+    if (backend == nullptr) {
-+        LOG_ERR("%s: failed to initialize backend\n", __func__);
-+        clip_free(new_clip);
-+        gguf_free(ctx);
-+        return nullptr;
-     }
-+    LOG_INF("%s: using %s backend\n", __func__, ggml_backend_name(backend));
-+    new_clip->backend = backend;
- 
-     // model size and capabilities
-     {
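
The deleted patch had switched clip.cpp from compile-time backend #ifdefs to runtime selection via ggml_backend_init_best(). A sketch of that pattern, assuming ggml-backend.h is on the include path; the init_clip_backend wrapper is illustrative, not part of the ollama or llama.cpp source:

    #include "ggml-backend.h"

    #include <cstdio>

    // pick the best available backend at runtime instead of hard-coding one per
    // build, and fail loudly if none can be initialized
    static ggml_backend_t init_clip_backend(void) {
        ggml_backend_t backend = ggml_backend_init_best();
        if (backend == nullptr) {
            std::fprintf(stderr, "failed to initialize a ggml backend\n");
            return nullptr;
        }
        std::fprintf(stderr, "using %s backend\n", ggml_backend_name(backend));
        return backend;
    }
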
diff --git a/llama/patches/0014-add-phony-target-ggml-cpu-for-all-cpu-variants.patch b/llama/patches/0013-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
similarity index 65%
rename from llama/patches/0014-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
rename to llama/patches/0013-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
index f263ece17..eaba3c4c8 100644
--- a/llama/patches/0014-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
+++ b/llama/patches/0013-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
@@ -1,6 +1,6 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang 
-Date: Tue, 14 Jan 2025 15:59:04 -0800
+From: jmorganca 
+Date: Tue, 8 Apr 2025 20:32:07 -0700
 Subject: [PATCH] add phony target ggml-cpu for all cpu variants
 
 ---
@@ -8,10 +8,10 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants
  1 file changed, 2 insertions(+)
 
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 0002ac18..0a8d1092 100644
+index f00700da..91d6a7d5 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
-@@ -297,6 +297,7 @@ function(ggml_add_cpu_backend_variant tag_name)
+@@ -278,6 +278,7 @@ function(ggml_add_cpu_backend_variant tag_name)
      endforeach()
  
      ggml_add_cpu_backend_variant_impl(${tag_name})
@@ -19,11 +19,11 @@ index 0002ac18..0a8d1092 100644
  endfunction()
  
  ggml_add_backend(CPU)
-@@ -305,6 +306,7 @@ if (GGML_CPU_ALL_VARIANTS)
+@@ -286,6 +287,7 @@ if (GGML_CPU_ALL_VARIANTS)
      if (NOT GGML_BACKEND_DL)
          message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
      endif()
 +    add_custom_target(ggml-cpu)
      ggml_add_cpu_backend_variant(sandybridge    AVX)
-     ggml_add_cpu_backend_variant(haswell        AVX F16C AVX2 FMA)
-     ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 FMA AVX512)
+     ggml_add_cpu_backend_variant(haswell        AVX F16C AVX2 BMI2 FMA)
+     ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 BMI2 FMA AVX512)
diff --git a/llama/patches/0014-remove-amx.patch b/llama/patches/0014-remove-amx.patch
new file mode 100644
index 000000000..0bbc0a3a3
--- /dev/null
+++ b/llama/patches/0014-remove-amx.patch
@@ -0,0 +1,25 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca 
+Date: Tue, 8 Apr 2025 20:33:01 -0700
+Subject: [PATCH] remove amx
+
+disable amx as it reduces performance on some systems
+---
+ ggml/src/CMakeLists.txt | 4 ----
+ 1 file changed, 4 deletions(-)
+
+diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
+index 91d6a7d5..d6b393a2 100644
+--- a/ggml/src/CMakeLists.txt
++++ b/ggml/src/CMakeLists.txt
+@@ -293,10 +293,6 @@ if (GGML_CPU_ALL_VARIANTS)
+     ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 BMI2 FMA AVX512)
+     ggml_add_cpu_backend_variant(icelake        AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
+     ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 BMI2 FMA AVX_VNNI)
+-    if (NOT MSVC)
+-        # MSVC doesn't support AMX
+-        ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
+-    endif()
+ elseif (GGML_CPU)
+     ggml_add_cpu_backend_variant_impl("")
+ endif()
diff --git a/llama/patches/0019-fix-string-arr-kv-loading.patch b/llama/patches/0015-fix-string-arr-kv-loading.patch
similarity index 91%
rename from llama/patches/0019-fix-string-arr-kv-loading.patch
rename to llama/patches/0015-fix-string-arr-kv-loading.patch
index aa7b4d3cf..ae7e4ca9f 100644
--- a/llama/patches/0019-fix-string-arr-kv-loading.patch
+++ b/llama/patches/0015-fix-string-arr-kv-loading.patch
@@ -1,8 +1,11 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
 From: jmorganca 
-Date: Wed, 5 Mar 2025 17:41:07 -0800
+Date: Tue, 8 Apr 2025 20:35:53 -0700
 Subject: [PATCH] fix string arr kv loading
 
+certain models would error when loading
+kv metadata fields that contain an array of strings
+such as vocab fields
 ---
  ggml/include/gguf.h | 1 +
  ggml/src/gguf.cpp   | 7 +++++--
@@ -22,7 +25,7 @@ index 79ee2020..3efb22f0 100644
      // get ith C string from array with given key_id
      GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
 diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
-index ab13669c..f75b923f 100644
+index 381a9c7d..e45b453d 100644
 --- a/ggml/src/gguf.cpp
 +++ b/ggml/src/gguf.cpp
 @@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
@@ -50,10 +53,10 @@ index ab13669c..f75b923f 100644
  }
  
 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index c7ff28be..7a185443 100644
+index d74919d2..c90f636c 100644
 --- a/src/llama-vocab.cpp
 +++ b/src/llama-vocab.cpp
-@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+@@ -1459,7 +1459,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
  
              const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
              if (precompiled_charsmap_keyidx != -1) {
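
The renamed patch makes array-of-string KV fields (such as vocab metadata) loadable. For reference, string arrays are read per element with gguf_get_arr_str rather than through the raw array data pointer; a minimal sketch, where the read_string_arr helper and key name are illustrative only:

    #include "gguf.h"

    #include <cstddef>
    #include <cstdint>
    #include <string>
    #include <vector>

    // read an array-of-strings KV field (for example a vocab list) element by element
    static std::vector<std::string> read_string_arr(const struct gguf_context * ctx, const char * key) {
        std::vector<std::string> out;
        const int64_t key_id = gguf_find_key(ctx, key);
        if (key_id < 0 || gguf_get_arr_type(ctx, key_id) != GGUF_TYPE_STRING) {
            return out;
        }
        const size_t n = gguf_get_arr_n(ctx, key_id);
        for (size_t i = 0; i < n; ++i) {
            out.emplace_back(gguf_get_arr_str(ctx, key_id, i));
        }
        return out;
    }
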
diff --git a/llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch b/llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch
deleted file mode 100644
index e72d78ac8..000000000
--- a/llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch
+++ /dev/null
@@ -1,369 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Sun, 16 Feb 2025 20:00:22 -0500
-Subject: [PATCH] use std::filesystem::path instead of wstring
-
----
- ggml/src/ggml-backend-reg.cpp | 199 +++++++++++++++-------------------
- 1 file changed, 88 insertions(+), 111 deletions(-)
-
-diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 98d5e14d..799af5f3 100644
---- a/ggml/src/ggml-backend-reg.cpp
-+++ b/ggml/src/ggml-backend-reg.cpp
-@@ -66,26 +66,6 @@
- #include "ggml-kompute.h"
- #endif
- 
--// disable C++17 deprecation warning for std::codecvt_utf8
--#if defined(__clang__)
--#    pragma clang diagnostic push
--#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
--#endif
--
--static std::wstring utf8_to_utf16(const std::string & str) {
--    std::wstring_convert> converter;
--    return converter.from_bytes(str);
--}
--
--static std::string utf16_to_utf8(const std::wstring & str) {
--    std::wstring_convert> converter;
--    return converter.to_bytes(str);
--}
--
--#if defined(__clang__)
--#    pragma clang diagnostic pop
--#endif
--
- #ifdef _WIN32
- 
- using dl_handle = std::remove_pointer_t;
-@@ -96,7 +76,7 @@ struct dl_handle_deleter {
-     }
- };
- 
--static dl_handle * dl_load_library(const std::wstring & path) {
-+static dl_handle * dl_load_library(const std::filesystem::path & path) {
-     // suppress error dialogs for missing DLLs
-     DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
-     SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
-@@ -129,8 +109,8 @@ struct dl_handle_deleter {
-     }
- };
- 
--static void * dl_load_library(const std::wstring & path) {
--    dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
-+static void * dl_load_library(const std::filesystem::path & path) {
-+    dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
- 
-     return handle;
- }
-@@ -141,6 +121,25 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
- 
- #endif
- 
-+static std::string path_to_string(const std::filesystem::path & path)
-+{
-+#ifdef _WIN32
-+    const std::wstring wstr = path.wstring();
-+    const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr);
-+    if (size_needed <= 0) {
-+        return std::string();
-+    }
-+
-+    // size_needed includes the null terminator
-+    std::string str(size_needed - 1, '\0');
-+    WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr);
-+    return str;
-+#else
-+    return path.string();
-+#endif
-+}
-+
-+
- using dl_handle_ptr = std::unique_ptr;
- 
- struct ggml_backend_reg_entry {
-@@ -222,11 +221,11 @@ struct ggml_backend_registry {
-         );
-     }
- 
--    ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
-+    ggml_backend_reg_t load_backend(const std::filesystem::path & path, bool silent) {
-         dl_handle_ptr handle { dl_load_library(path) };
-         if (!handle) {
-             if (!silent) {
--                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
-+                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str());
-             }
-             return nullptr;
-         }
-@@ -234,7 +233,7 @@ struct ggml_backend_registry {
-         auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
-         if (score_fn && score_fn() == 0) {
-             if (!silent) {
--                GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
-+                GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str());
-             }
-             return nullptr;
-         }
-@@ -242,7 +241,7 @@ struct ggml_backend_registry {
-         auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
-         if (!backend_init_fn) {
-             if (!silent) {
--                GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
-+                GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str());
-             }
-             return nullptr;
-         }
-@@ -251,16 +250,16 @@ struct ggml_backend_registry {
-         if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
-             if (!silent) {
-                 if (!reg) {
--                    GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
-+                    GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str());
-                 } else {
-                     GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
--                        __func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
-+                        __func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
-                 }
-             }
-             return nullptr;
-         }
- 
--        GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
-+        GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str());
- 
-         register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
- 
-@@ -396,14 +395,14 @@ ggml_backend_t ggml_backend_init_best(void) {
- 
- // Dynamic loading
- ggml_backend_reg_t ggml_backend_load(const char * path) {
--    return get_reg().load_backend(utf8_to_utf16(path), false);
-+    return get_reg().load_backend(path, false);
- }
- 
- void ggml_backend_unload(ggml_backend_reg_t reg) {
-     get_reg().unload_backend(reg, true);
- }
- 
--static std::wstring get_executable_path() {
-+static std::filesystem::path get_executable_path() {
- #if defined(__APPLE__)
-     // get executable path
-     std::vector path;
-@@ -415,15 +414,9 @@ static std::wstring get_executable_path() {
-         }
-         path.resize(size);
-     }
--    std::string base_path(path.data(), size);
--    // remove executable name
--    auto last_slash = base_path.find_last_of('/');
--    if (last_slash != std::string::npos) {
--        base_path = base_path.substr(0, last_slash);
--    }
--    return utf8_to_utf16(base_path + "/");
-+
-+    return std::filesystem::path(path.data()).parent_path();
- #elif defined(__linux__) || defined(__FreeBSD__)
--    std::string base_path = ".";
-     std::vector path(1024);
-     while (true) {
-         // get executable path
-@@ -436,76 +429,55 @@ static std::wstring get_executable_path() {
-             break;
-         }
-         if (len < (ssize_t) path.size()) {
--            base_path = std::string(path.data(), len);
--            // remove executable name
--            auto last_slash = base_path.find_last_of('/');
--            if (last_slash != std::string::npos) {
--                base_path = base_path.substr(0, last_slash);
--            }
--            break;
-+            return std::filesystem::path(path.data()).parent_path();
-         }
-         path.resize(path.size() * 2);
-     }
--
--    return utf8_to_utf16(base_path + "/");
- #elif defined(_WIN32)
-     std::vector path(MAX_PATH);
-     DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
-     if (len == 0) {
-         return {};
-     }
--    std::wstring base_path(path.data(), len);
--    // remove executable name
--    auto last_slash = base_path.find_last_of('\\');
--    if (last_slash != std::string::npos) {
--        base_path = base_path.substr(0, last_slash);
--    }
--    return base_path + L"\\";
--#else
--    return {};
--#endif
--}
- 
--static std::wstring backend_filename_prefix() {
--#ifdef _WIN32
--    return L"ggml-";
--#else
--    return L"libggml-";
-+    return std::filesystem::path(path.data()).parent_path();
- #endif
-+    return {};
- }
- 
--static std::wstring backend_filename_suffix() {
-+static std::string backend_filename_prefix() {
- #ifdef _WIN32
--    return L".dll";
-+    return "ggml-";
- #else
--    return L".so";
-+    return "libggml-";
- #endif
- }
- 
--static std::wstring path_separator() {
-+static std::string backend_filename_suffix() {
- #ifdef _WIN32
--    return L"\\";
-+    return ".dll";
- #else
--    return L"/";
-+    return ".so";
- #endif
- }
- 
- static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
-     // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
-      // TODO: search system paths
--    std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-";
--    std::vector search_paths;
-+    namespace fs = std::filesystem;
-+    std::string file_prefix = backend_filename_prefix() + name + "-";
-+    std::vector search_paths;
-+
-     if (user_search_path == nullptr) {
--        search_paths.push_back(L"." + path_separator());
-+        search_paths.push_back(fs::current_path());
-         search_paths.push_back(get_executable_path());
-     } else {
--        search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator());
-+        search_paths.push_back(fs::u8path(user_search_path));
-     }
- 
-     int best_score = 0;
--    std::wstring best_path;
-+    fs::path best_path;
- 
--    namespace fs = std::filesystem;
-     for (const auto & search_path : search_paths) {
-         if (!fs::exists(search_path)) {
-             continue;
-@@ -513,29 +485,26 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
-         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
-         for (const auto & entry : dir_it) {
-             if (entry.is_regular_file()) {
--                std::wstring filename = entry.path().filename().wstring();
--                std::wstring ext = entry.path().extension().wstring();
-+                std::string filename = entry.path().filename().string();
-+                std::string ext = entry.path().extension().string();
-                 if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
--                    dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
--                    if (!handle && !silent) {
--                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-+                    dl_handle_ptr handle { dl_load_library(entry.path()) };
-+                    if (!handle) {
-+                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
-+                        continue;
-                     }
--                    if (handle) {
--                        auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
--                        if (score_fn) {
--                            int s = score_fn();
--#ifndef NDEBUG
--                            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
--#endif
--                            if (s > best_score) {
--                                best_score = s;
--                                best_path = entry.path().wstring();
--                            }
--                        } else {
--                            if (!silent) {
--                                GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
--                            }
--                        }
-+
-+                    auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
-+                    if (!score_fn) {
-+                        GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
-+                        continue;
-+                    }
-+
-+                    int s = score_fn();
-+                    GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
-+                    if (s > best_score) {
-+                        best_score = s;
-+                        best_path = entry.path();
-                     }
-                 }
-             }
-@@ -545,7 +514,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
-     if (best_score == 0) {
-         // try to load the base backend
-         for (const auto & search_path : search_paths) {
--            std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix();
-+            fs::path path = fs::path(search_path) / (backend_filename_prefix() + name + backend_filename_suffix());
-             if (fs::exists(path)) {
-                 return get_reg().load_backend(path, silent);
-             }
-@@ -560,6 +529,14 @@ void ggml_backend_load_all() {
-     ggml_backend_load_all_from_path(nullptr);
- }
- 
-+static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
-+    try {
-+        ggml_backend_load_best(name, silent, user_search_path);
-+    } catch (const std::exception & e) {
-+        GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
-+    }
-+}
-+
- void ggml_backend_load_all_from_path(const char * dir_path) {
- #ifdef NDEBUG
-     bool silent = true;
-@@ -567,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
-     bool silent = false;
- #endif
- 
--    ggml_backend_load_best("blas", silent, dir_path);
--    ggml_backend_load_best("cann", silent, dir_path);
--    ggml_backend_load_best("cuda", silent, dir_path);
--    ggml_backend_load_best("hip", silent, dir_path);
--    ggml_backend_load_best("kompute", silent, dir_path);
--    ggml_backend_load_best("metal", silent, dir_path);
--    ggml_backend_load_best("rpc", silent, dir_path);
--    ggml_backend_load_best("sycl", silent, dir_path);
--    ggml_backend_load_best("vulkan", silent, dir_path);
--    ggml_backend_load_best("opencl", silent, dir_path);
--    ggml_backend_load_best("musa", silent, dir_path);
--    ggml_backend_load_best("cpu", silent, dir_path);
-+    ggml_backend_try_load_best("blas", silent, dir_path);
-+    ggml_backend_try_load_best("cann", silent, dir_path);
-+    ggml_backend_try_load_best("cuda", silent, dir_path);
-+    ggml_backend_try_load_best("hip", silent, dir_path);
-+    ggml_backend_try_load_best("kompute", silent, dir_path);
-+    ggml_backend_try_load_best("metal", silent, dir_path);
-+    ggml_backend_try_load_best("rpc", silent, dir_path);
-+    ggml_backend_try_load_best("sycl", silent, dir_path);
-+    ggml_backend_try_load_best("vulkan", silent, dir_path);
-+    ggml_backend_try_load_best("opencl", silent, dir_path);
-+    ggml_backend_try_load_best("musa", silent, dir_path);
-+    ggml_backend_try_load_best("cpu", silent, dir_path);
-     // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
-     const char * backend_path = std::getenv("GGML_BACKEND_PATH");
-     if (backend_path) {
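
The removed patch replaced the wstring-based path handling with std::filesystem::path when scanning for ggml backend libraries. A condensed standalone sketch of that directory scan under the same naming convention; list_backend_candidates is a made-up helper, and the Linux-style "libggml-<name>-*.so" pattern is assumed ("ggml-" and ".dll" apply on Windows):

    #include <filesystem>
    #include <string>
    #include <vector>

    namespace fs = std::filesystem;

    // list candidate backend libraries, e.g. libggml-cuda-*.so, in a search path
    static std::vector<fs::path> list_backend_candidates(const fs::path & dir, const std::string & name) {
        const std::string prefix = "libggml-" + name + "-";
        std::vector<fs::path> found;
        if (!fs::exists(dir)) {
            return found;
        }
        for (const auto & entry : fs::directory_iterator(dir, fs::directory_options::skip_permission_denied)) {
            if (!entry.is_regular_file()) {
                continue;
            }
            const std::string filename = entry.path().filename().string();
            if (filename.rfind(prefix, 0) == 0 && entry.path().extension() == ".so") {
                found.push_back(entry.path());
            }
        }
        return found;
    }
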
diff --git a/llama/patches/0020-ollama-debug-tensor.patch b/llama/patches/0016-ollama-debug-tensor.patch
similarity index 79%
rename from llama/patches/0020-ollama-debug-tensor.patch
rename to llama/patches/0016-ollama-debug-tensor.patch
index b9f2e4ab0..a192bdea8 100644
--- a/llama/patches/0020-ollama-debug-tensor.patch
+++ b/llama/patches/0016-ollama-debug-tensor.patch
@@ -1,6 +1,6 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang 
-Date: Sun, 9 Mar 2025 14:44:16 -0700
+From: jmorganca 
+Date: Tue, 8 Apr 2025 20:36:41 -0700
 Subject: [PATCH] ollama debug tensor
 
 ---
@@ -8,11 +8,11 @@ Subject: [PATCH] ollama debug tensor
  1 file changed, 6 insertions(+)
 
 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
-index 2f606d82..ec60e8fc 100644
+index 432942bf..6d4abe4c 100644
 --- a/ggml/src/ggml-cpu/ggml-cpu.c
 +++ b/ggml/src/ggml-cpu/ggml-cpu.c
-@@ -11,6 +11,8 @@
- #include "ggml-threading.h"
+@@ -15,6 +15,8 @@
+ #include "ops.h"
  #include "ggml.h"
  
 +#include "ollama-debug.h"
@@ -20,7 +20,7 @@ index 2f606d82..ec60e8fc 100644
  #if defined(_MSC_VER) || defined(__MINGW32__)
  #include <malloc.h> // using malloc.h with MSC/MINGW
  #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
-@@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
+@@ -2854,6 +2856,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
  
          ggml_compute_forward(¶ms, node);
  
diff --git a/llama/patches/0016-remove-amx.patch b/llama/patches/0016-remove-amx.patch
deleted file mode 100644
index 234d51cc7..000000000
--- a/llama/patches/0016-remove-amx.patch
+++ /dev/null
@@ -1,24 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang 
-Date: Tue, 18 Feb 2025 14:47:21 -0800
-Subject: [PATCH] remove amx
-
----
- ggml/src/CMakeLists.txt | 4 ----
- 1 file changed, 4 deletions(-)
-
-diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 0a8d1092..4564df91 100644
---- a/ggml/src/CMakeLists.txt
-+++ b/ggml/src/CMakeLists.txt
-@@ -312,10 +312,6 @@ if (GGML_CPU_ALL_VARIANTS)
-     ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 FMA AVX512)
-     ggml_add_cpu_backend_variant(icelake        AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
-     ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 FMA AVX_VNNI)
--    if (NOT MSVC)
--        # MSVC doesn't support AMX
--        ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
--    endif()
- elseif (GGML_CPU)
-     ggml_add_cpu_backend_variant_impl("")
- endif()
diff --git a/llama/patches/0017-add-model-quantizations.patch b/llama/patches/0017-add-model-quantizations.patch
new file mode 100644
index 000000000..1f40328c7
--- /dev/null
+++ b/llama/patches/0017-add-model-quantizations.patch
@@ -0,0 +1,96 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca 
+Date: Tue, 8 Apr 2025 20:39:32 -0700
+Subject: [PATCH] add model quantizations
+
+a temporary patch to add model quantization for
+models not supported in llama.cpp
+---
+ src/llama-arch.cpp  | 17 +++++++++++++++++
+ src/llama-arch.h    |  1 +
+ src/llama-model.cpp |  2 ++
+ src/llama-quant.cpp |  4 ++++
+ 4 files changed, 24 insertions(+)
+
+diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
+index c1f78618..bdf3d898 100644
+--- a/src/llama-arch.cpp
++++ b/src/llama-arch.cpp
+@@ -73,6 +73,7 @@ static const std::map LLM_ARCH_NAMES = {
+     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
+     { LLM_ARCH_PLM,              "plm"              },
+     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
++    { LLM_ARCH_MISTRAL3,         "mistral3"         },
+     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
+ };
+ 
+@@ -1582,6 +1583,22 @@ static const std::map> LLM_TENSOR_N
+             { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
+         },
+     },
++    {
++        LLM_ARCH_MISTRAL3,
++        {
++            { LLM_TENSOR_TOKEN_EMBD,  "token_embd" },
++            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
++            { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
++            { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
++            { LLM_TENSOR_ATTN_K,      "blk.%d.attn_k" },
++            { LLM_TENSOR_ATTN_V,      "blk.%d.attn_v" },
++            { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
++            { LLM_TENSOR_FFN_NORM,    "blk.%d.ffn_norm" },
++            { LLM_TENSOR_FFN_GATE,    "blk.%d.ffn_gate" },
++            { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
++            { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
++        }
++    },
+     {
+         LLM_ARCH_UNKNOWN,
+         {
+diff --git a/src/llama-arch.h b/src/llama-arch.h
+index f987844d..ee081fbf 100644
+--- a/src/llama-arch.h
++++ b/src/llama-arch.h
+@@ -75,6 +75,7 @@ enum llm_arch {
+     LLM_ARCH_CHAMELEON,
+     LLM_ARCH_SOLAR,
+     LLM_ARCH_WAVTOKENIZER_DEC,
++    LLM_ARCH_MISTRAL3,
+     LLM_ARCH_PLM,
+     LLM_ARCH_BAILINGMOE,
+     LLM_ARCH_UNKNOWN,
+diff --git a/src/llama-model.cpp b/src/llama-model.cpp
+index d5ad466e..cd1d239c 100644
+--- a/src/llama-model.cpp
++++ b/src/llama-model.cpp
+@@ -1423,6 +1423,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+                     default: type = LLM_TYPE_UNKNOWN;
+                 }
+             } break;
++        case LLM_ARCH_MISTRAL3: break;
+         default: throw std::runtime_error("unsupported model architecture");
+     }
+ 
+@@ -13652,6 +13653,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
+         case LLM_ARCH_CHAMELEON:
+         case LLM_ARCH_SOLAR:
+         case LLM_ARCH_BAILINGMOE:
++        case LLM_ARCH_MISTRAL3:
+             return LLAMA_ROPE_TYPE_NORM;
+ 
+         // the pairs of head values are offset by n_rot/2
+diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
+index 223e1f3f..8ae6dde8 100644
+--- a/src/llama-quant.cpp
++++ b/src/llama-quant.cpp
+@@ -744,6 +744,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
+         // This used to be a regex, but  has an extreme cost to compile times.
+         bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
+ 
++        // don't quantize vision stuff
++        quantize &= name.find("v.") == std::string::npos;
++        quantize &= name.find("mm.") == std::string::npos;
++
+         // quantize only 2D and 3D tensors (experts)
+         quantize &= (ggml_n_dims(tensor) >= 2);
+ 
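
The new patch skips vision tensors during quantization by filtering on tensor names. A standalone sketch of that filter; should_quantize is a hypothetical helper, and the dimension and architecture checks from the surrounding llama-quant.cpp code are omitted:

    #include <string>

    // only quantize tensors whose name ends in "weight", and skip the
    // vision ("v.") and multimodal projector ("mm.") tensors
    static bool should_quantize(const std::string & name) {
        const std::string suffix = "weight";
        bool quantize = name.size() >= suffix.size() &&
                        name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
        quantize &= name.find("v.") == std::string::npos;
        quantize &= name.find("mm.") == std::string::npos;
        return quantize;
    }
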
diff --git a/llama/patches/0017-fix-clip-compiler-error.patch b/llama/patches/0017-fix-clip-compiler-error.patch
deleted file mode 100644
index ef6e247bd..000000000
--- a/llama/patches/0017-fix-clip-compiler-error.patch
+++ /dev/null
@@ -1,36 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Tue, 25 Feb 2025 19:14:51 -0800
-Subject: [PATCH] fix-clip-compiler-error
-
----
- examples/llava/clip.cpp | 2 +-
- examples/llava/clip.h   | 2 +-
- 2 files changed, 2 insertions(+), 2 deletions(-)
-
-diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
-index 560021c7..54265beb 100644
---- a/examples/llava/clip.cpp
-+++ b/examples/llava/clip.cpp
-@@ -1788,7 +1788,7 @@ void clip_image_f32_batch_free(struct clip_image_f32_batch  * batch) {
-     }
- }
- 
--void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
-+void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img) {
-     img->nx = nx;
-     img->ny = ny;
-     img->buf.resize(3 * nx * ny);
-diff --git a/examples/llava/clip.h b/examples/llava/clip.h
-index ce6f6194..f9f80d7d 100644
---- a/examples/llava/clip.h
-+++ b/examples/llava/clip.h
-@@ -75,7 +75,7 @@ CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch  * batch);
- CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
- 
- /** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */
--CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img);
-+CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
- 
- CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
- 
diff --git a/llama/patches/0022-metal-add-op_neg.patch b/llama/patches/0018-add-op_neg.patch
similarity index 72%
rename from llama/patches/0022-metal-add-op_neg.patch
rename to llama/patches/0018-add-op_neg.patch
index a903535f2..b9cab0e34 100644
--- a/llama/patches/0022-metal-add-op_neg.patch
+++ b/llama/patches/0018-add-op_neg.patch
@@ -1,18 +1,19 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang 
-Date: Wed, 2 Apr 2025 15:26:15 -0700
-Subject: [PATCH] metal: add op_neg
+From: jmorganca 
+Date: Tue, 8 Apr 2025 20:41:24 -0700
+Subject: [PATCH] add op_neg
 
+adds the neg operator to ggml
 ---
  ggml/src/ggml-metal/ggml-metal.m     | 15 +++++++++++++++
  ggml/src/ggml-metal/ggml-metal.metal |  7 +++++++
  2 files changed, 22 insertions(+)
 
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index e4c093f9..d8422f1b 100644
+index b121ab9e..fea50521 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
-@@ -423,6 +423,7 @@ enum ggml_metal_kernel_type {
+@@ -461,6 +461,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
      GGML_METAL_KERNEL_TYPE_SQRT,
      GGML_METAL_KERNEL_TYPE_SIN,
      GGML_METAL_KERNEL_TYPE_COS,
@@ -20,23 +21,23 @@ index e4c093f9..d8422f1b 100644
      GGML_METAL_KERNEL_TYPE_SUM_ROWS,
      GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
      GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
-@@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
-+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                           neg,                            true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                        argmax,                         true);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,               pool_2d_avg_f32,                true);
-@@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
+@@ -1119,6 +1120,7 @@ @implementation GGMLMetalClass
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                            sqrt,                            true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                             sin,                             true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                             cos,                             true);
++        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                             neg,                             true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
+@@ -1280,6 +1282,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                  case GGML_UNARY_OP_GELU_QUICK:
                  case GGML_UNARY_OP_SILU:
                  case GGML_UNARY_OP_ELU:
 +                case GGML_UNARY_OP_NEG:
-                     return ggml_is_contiguous(op->src[0]);
+                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                  default:
                      return false;
-@@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node(
+@@ -1966,6 +1969,18 @@ static void ggml_metal_encode_node(
  
                      [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                  } break;
@@ -56,10 +57,10 @@ index e4c093f9..d8422f1b 100644
                  {
                      GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index f38909d0..bb0ff668 100644
+index e3185e5b..ede9d1e6 100644
 --- a/ggml/src/ggml-metal/ggml-metal.metal
 +++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -945,6 +945,13 @@ kernel void kernel_cos(
+@@ -949,6 +949,13 @@ kernel void kernel_cos(
      dst[tpig] = cos(src0[tpig]);
  }
  
diff --git a/llama/patches/0018-add-phi4-support.patch b/llama/patches/0018-add-phi4-support.patch
deleted file mode 100644
index 1cdc81715..000000000
--- a/llama/patches/0018-add-phi4-support.patch
+++ /dev/null
@@ -1,80 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Thu, 27 Feb 2025 15:12:26 -0800
-Subject: [PATCH] add phi4 support
-
----
- include/llama.h     |  1 +
- src/llama-model.cpp | 10 +++++++---
- src/llama-vocab.cpp | 11 +++++++++++
- 3 files changed, 19 insertions(+), 3 deletions(-)
-
-diff --git a/include/llama.h b/include/llama.h
-index cc948005..16774711 100644
---- a/include/llama.h
-+++ b/include/llama.h
-@@ -105,6 +105,7 @@ extern "C" {
-         LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
-         LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
-         LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
-+        LLAMA_VOCAB_PRE_TYPE_GPT4O          = 29,
-     };
- 
-     enum llama_rope_type {
-diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index 21819080..ab1a07d1 100644
---- a/src/llama-model.cpp
-+++ b/src/llama-model.cpp
-@@ -2283,7 +2283,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
- 
-                     // output
-                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
--                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
-+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
-+                    // if output is NULL, init from the input tok embed
-+                    if (output == NULL) {
-+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
-+                    }
- 
-                     for (int i = 0; i < n_layer; ++i) {
-                         auto & layer = layers[i];
-@@ -2298,8 +2302,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
-                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
-                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
- 
--                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
--                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
-+                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
-+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
-                     }
-                 } break;
-             case LLM_ARCH_PHIMOE:
-diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index 1ca827eb..c7ff28be 100644
---- a/src/llama-vocab.cpp
-+++ b/src/llama-vocab.cpp
-@@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
-                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                 };
-                 break;
-+            case LLAMA_VOCAB_PRE_TYPE_GPT4O:
-+                // original regex from tokenizer.json
-+                // [^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+
-+                regex_exprs = {
-+                    "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
-+                };
-+                break;
-             default:
-                 // default regex for BPE tokenization pre-processing
-                 regex_exprs = {
-@@ -1583,6 +1590,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
-             } else if (
-                 tokenizer_pre == "megrez") {
-                 pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
-+            } else if (
-+                tokenizer_pre == "gpt-4o") {
-+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
-+                clean_spaces = false;
-             } else {
-                 LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
-                 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
diff --git a/llama/patches/0019-fix-compiler-error-in-clip.h.patch b/llama/patches/0019-fix-compiler-error-in-clip.h.patch
new file mode 100644
index 000000000..a0b90df18
--- /dev/null
+++ b/llama/patches/0019-fix-compiler-error-in-clip.h.patch
@@ -0,0 +1,39 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca 
+Date: Tue, 8 Apr 2025 20:49:50 -0700
+Subject: [PATCH] fix compiler error in clip.h
+
+fixes an error that occurs in clip.h when compiling
+using CGo
+---
+ examples/llava/clip.h | 5 +++--
+ 1 file changed, 3 insertions(+), 2 deletions(-)
+
+diff --git a/examples/llava/clip.h b/examples/llava/clip.h
+index cc133a58..5fc45d3e 100644
+--- a/examples/llava/clip.h
++++ b/examples/llava/clip.h
+@@ -30,12 +30,13 @@ struct clip_image_size {
+     int height;
+ };
+ 
++struct clip_image_f32;
+ struct clip_image_u8_batch;
+ struct clip_image_f32_batch;
+ 
+ struct clip_context_params {
+     bool use_gpu;
+-    ggml_log_level verbosity;
++    enum ggml_log_level verbosity;
+ };
+ 
+ // deprecated, use clip_init
+@@ -84,7 +85,7 @@ CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
+ CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
+ CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
+ CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
+-CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
++CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
+ 
+ /**
+  * Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
diff --git a/llama/patches/0020-Revert-Simplify-and-improve-CUDA-graphs-through-use-.patch b/llama/patches/0020-Revert-Simplify-and-improve-CUDA-graphs-through-use-.patch
new file mode 100644
index 000000000..a9d41d2a3
--- /dev/null
+++ b/llama/patches/0020-Revert-Simplify-and-improve-CUDA-graphs-through-use-.patch
@@ -0,0 +1,600 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca 
+Date: Sat, 12 Apr 2025 13:06:57 -0700
+Subject: [PATCH] Revert "Simplify and improve CUDA graphs through use of
+ indirect copy pointers (#9017)"
+
+this commit in llama.cpp causes errors when running llama 3.2
+vision - temporarily revert it
+
+This reverts commit 3f9da22c2b21a2cef216de50006436ef1cab8764.
+---
+ ggml/src/ggml-cuda/common.cuh   |   8 +-
+ ggml/src/ggml-cuda/cpy.cu       | 149 ++++++++++++--------------------
+ ggml/src/ggml-cuda/cpy.cuh      |   2 -
+ ggml/src/ggml-cuda/ggml-cuda.cu |  93 +++++++++++++++-----
+ 4 files changed, 124 insertions(+), 128 deletions(-)
+
+diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
+index 8284a001..a718b6a1 100644
+--- a/ggml/src/ggml-cuda/common.cuh
++++ b/ggml/src/ggml-cuda/common.cuh
+@@ -729,13 +729,7 @@ struct ggml_cuda_graph {
+     bool disable_due_to_failed_graph_capture = false;
+     int number_consecutive_updates = 0;
+     std::vector ggml_graph_properties;
+-    bool use_cpy_indirection = false;
+-    std::vector cpy_dest_ptrs;
+-    char ** dest_ptrs_d;
+-    int dest_ptrs_size = 0;
+-    // Index to allow each cpy kernel to be aware of it's position within the graph
+-    // relative to other cpy nodes.
+-    int graph_cpynode_index = -1;
++    std::vector updated_kernel_arg;
+ #endif
+ };
+ 
+diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
+index 4f4faa3e..8396df28 100644
+--- a/ggml/src/ggml-cuda/cpy.cu
++++ b/ggml/src/ggml-cuda/cpy.cu
+@@ -39,18 +39,16 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+ }
+ 
+ template 
+-static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
++static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+                                    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+-                                   const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
++                                   const int nb12, const int nb13) {
+     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+ 
+     if (i >= ne) {
+         return;
+     }
+ 
+-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+-
+     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+     // then combine those indices with the corresponding byte offsets to get the total offsets
+     const int64_t i03 = i/(ne00 * ne01 * ne02);
+@@ -297,18 +295,16 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
+ }
+ 
+ template 
+-static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
++static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+-                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
++                                 const int nb12, const int nb13) {
+     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+ 
+     if (i >= ne) {
+         return;
+     }
+ 
+-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+-
+     const int i03 = i/(ne00 * ne01 * ne02);
+     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+     const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+@@ -325,18 +321,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
+ }
+ 
+ template 
+-static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
++static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+-                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
++                                 const int nb12, const int nb13) {
+     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+ 
+     if (i >= ne) {
+         return;
+     }
+ 
+-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+-
+     const int i03 = i/(ne00 * ne01 * ne02);
+     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+     const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+@@ -352,97 +346,76 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
+     cpy_blck(cx + x_offset, cdst + dst_offset);
+ }
+ 
+-// Copy destination pointers to GPU to be available when pointer indirection is in use
+-
+-void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
+-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+-    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
+-        CUDA_CHECK(cudaStreamSynchronize(stream));
+-        if (cuda_graph->dest_ptrs_d != nullptr) {
+-            CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
+-        }
+-        CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
+-        cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
+-    }
+-    // copy destination pointers to GPU
+-    CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
+-    cuda_graph->graph_cpynode_index = 0; // reset index
+-#else
+-    GGML_UNUSED(cuda_graph); GGML_UNUSED(host_dest_ptrs);
+-    GGML_UNUSED(host_dest_ptrs_size); GGML_UNUSED(stream);
+-#endif
+-}
+-
+ static void ggml_cpy_f16_f32_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+     cpy_f32_f16<<>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_f32_f32_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+     cpy_f32_f16<<>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_f32_bf16_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+     cpy_f32_f16<<>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_f32_f16_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+     cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_f32_q8_0_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     GGML_ASSERT(ne % QK8_0 == 0);
+     const int num_blocks = ne / QK8_0;
+     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_q8_0_f32_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     const int num_blocks = ne;
+     cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_f32_q4_0_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     GGML_ASSERT(ne % QK4_0 == 0);
+     const int num_blocks = ne / QK4_0;
+     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_q4_0_f32_cuda(
+@@ -451,22 +424,22 @@ static void ggml_cpy_q4_0_f32_cuda(
+     const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12,
+     const int nb10, const int nb11, const int nb12, const int nb13,
+-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    cudaStream_t stream) {
+     const int num_blocks = ne;
+     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
+         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+-         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_f32_q4_1_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     GGML_ASSERT(ne % QK4_1 == 0);
+     const int num_blocks = ne / QK4_1;
+     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_q4_1_f32_cuda(
+@@ -475,22 +448,22 @@ static void ggml_cpy_q4_1_f32_cuda(
+     const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12,
+     const int nb10, const int nb11, const int nb12, const int nb13,
+-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    cudaStream_t stream) {
+     const int num_blocks = ne;
+     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
+         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+-         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_f32_q5_0_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     GGML_ASSERT(ne % QK5_0 == 0);
+     const int num_blocks = ne / QK5_0;
+     cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_q5_0_f32_cuda(
+@@ -499,22 +472,22 @@ static void ggml_cpy_q5_0_f32_cuda(
+     const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12,
+     const int nb10, const int nb11, const int nb12, const int nb13,
+-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    cudaStream_t stream) {
+     const int num_blocks = ne;
+     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
+         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_f32_q5_1_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     GGML_ASSERT(ne % QK5_1 == 0);
+     const int num_blocks = ne / QK5_1;
+     cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_q5_1_f32_cuda(
+@@ -523,32 +496,32 @@ static void ggml_cpy_q5_1_f32_cuda(
+     const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12,
+     const int nb10, const int nb11, const int nb12, const int nb13,
+-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    cudaStream_t stream) {
+     const int num_blocks = ne;
+     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
+         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_f32_iq4_nl_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     GGML_ASSERT(ne % QK4_NL == 0);
+     const int num_blocks = ne / QK4_NL;
+     cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ static void ggml_cpy_f16_f16_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
++    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ 
+     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+     cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
++        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+ 
+ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
+@@ -585,62 +558,48 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
+     char * src0_ddc = (char *) src0->data;
+     char * src1_ddc = (char *) src1->data;
+ 
+-    char ** dest_ptrs_d = nullptr;
+-    int graph_cpynode_index = -1;
+-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+-    if(ctx.cuda_graph->use_cpy_indirection) {
+-        dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
+-        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
+-    }
+-#endif
+     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
+         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+-        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+-        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+-        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+-        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+         ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+         ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+-        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+         ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+-        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+-        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+-        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+-        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
++        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else {
+         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
+                 ggml_type_name(src0->type), ggml_type_name(src1->type));
+     }
+-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+-    if(ctx.cuda_graph->use_cpy_indirection) {
+-        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
+-    }
+-#endif
+-
+ }
+ 
+ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh
+index 6bed0564..28b06cdd 100644
+--- a/ggml/src/ggml-cuda/cpy.cuh
++++ b/ggml/src/ggml-cuda/cpy.cuh
+@@ -7,5 +7,3 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
+ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+ 
+ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
+-
+-void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
+diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
+index 67208cba..a44788db 100644
+--- a/ggml/src/ggml-cuda/ggml-cuda.cu
++++ b/ggml/src/ggml-cuda/ggml-cuda.cu
+@@ -2477,11 +2477,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
+ 
+ #ifdef USE_CUDA_GRAPH
+ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+-    bool use_cuda_graph) {
++    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+ 
+     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+-    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
+-
++    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+     for (int i = 0; i < cgraph->n_nodes; i++) {
+         ggml_tensor * node = cgraph->nodes[i];
+ 
+@@ -2513,11 +2512,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
+         }
+ 
+         if (node->op == GGML_OP_CPY) {
+-
+-            // Store the pointers which are updated for each token, such that these can be sent
+-            // to the device and accessed using indirection from CUDA graph
+-            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
+-
++            // store the copy op parameter which changes with each token.
++            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+             // store a pointer to each copy op CUDA kernel to identify it later
+             void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+             if (!ptr) {
+@@ -2525,6 +2521,10 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
+ #ifndef NDEBUG
+                 GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+ #endif
++            } else {
++                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
++                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
++                }
+             }
+         }
+ 
+@@ -2533,12 +2533,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
+         }
+     }
+ 
+-    if (use_cuda_graph) {
+-        cuda_ctx->cuda_graph->use_cpy_indirection = true;
+-        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
+-        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
+-    }
+-
+     return use_cuda_graph;
+ }
+ 
+@@ -2593,6 +2587,51 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
+     return true;
+ }
+ 
++static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
++
++    if (cuda_graph_update_required) {
++        // Extract nodes from graph
++        // First call with null argument gets number of nodes in graph
++        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
++        // Subsequent call with non-null argument gets nodes
++        cuda_ctx->cuda_graph->nodes.clear();
++        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
++        cuda_ctx->cuda_graph->params.clear();
++        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
++        if (cuda_ctx->cuda_graph->num_nodes > 0) {
++            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
++
++            // Loop over nodes, and extract kernel parameters from each node
++            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
++                cudaGraphNodeType node_type;
++                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
++                if (node_type == cudaGraphNodeTypeKernel) {
++                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
++                    if (stat == cudaErrorInvalidDeviceFunction) {
++                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
++                        // We don't need to update blas nodes, so clear error and move on.
++                        (void)cudaGetLastError();
++                    } else {
++                        GGML_ASSERT(stat == cudaSuccess);
++                    }
++                }
++            }
++        }
++    } else {
++        // One of the arguments to the copy kernel is updated for each token, hence we need to
++        // replace that argument with the updated value in the CUDA graph
++        // on update steps, the live parameters will already be captured
++        int k = 0;
++        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
++            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
++                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
++                *(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
++                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
++            }
++        }
++    }
++}
++
+ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
+ 
+     bool cuda_graph_update_required = false;
+@@ -2652,7 +2691,8 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
+ #endif
+ 
+ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+-    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
+   [[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs,  bool & graph_evaluated_or_captured, bool & use_cuda_graph,
++    bool & cuda_graph_update_required) {
+ 
+     while (!graph_evaluated_or_captured) {
+         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+@@ -2702,9 +2742,13 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
+         if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+             CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+         }
+-        if (cuda_graph_update_required) { // Update graph executable
+-            update_cuda_graph_executable(cuda_ctx);
+-        }
++
++        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
++        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
++
++        // Update graph executable
++        update_cuda_graph_executable(cuda_ctx);
++
+         // Launch graph
+         CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+ #else
+@@ -2718,6 +2762,10 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
+ 
+     ggml_cuda_set_device(cuda_ctx->device);
+ 
++    // vector of pointers to CUDA cpy kernels, which are required to identify
++    // kernel parameters which need updated in the graph for each token
++    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
++
+ #ifdef USE_CUDA_GRAPH
+     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+ 
+@@ -2751,7 +2799,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
+     if (use_cuda_graph) {
+         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
+ 
+-        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
++        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
++                             ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
+ 
+         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+         if (use_cuda_graph && cuda_graph_update_required) {
+@@ -2772,10 +2821,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
+         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+     }
+ 
+-    if (!use_cuda_graph) {
+-        cuda_ctx->cuda_graph->use_cpy_indirection = false;
+-    }
+-
+ #else
+     bool use_cuda_graph = false;
+     bool cuda_graph_update_required = false;
+@@ -2783,7 +2828,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
+ 
+     bool graph_evaluated_or_captured = false;
+ 
+-    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
++    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
+ 
+     return GGML_STATUS_SUCCESS;
+ }
diff --git a/llama/patches/0021-add-model-quantizations.patch b/llama/patches/0021-add-model-quantizations.patch
deleted file mode 100644
index cdc35a412..000000000
--- a/llama/patches/0021-add-model-quantizations.patch
+++ /dev/null
@@ -1,173 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Patrick Devine 
-Date: Fri, 14 Mar 2025 16:33:23 -0700
-Subject: [PATCH] add model quantizations
-
-- gemma3
-- mistral3
----
- src/llama-arch.cpp  | 36 ++++++++++++++++++++++++++++++++++++
- src/llama-arch.h    |  2 ++
- src/llama-model.cpp | 10 ++++++++++
- src/llama-quant.cpp |  4 ++++
- 4 files changed, 52 insertions(+)
-
-diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index b6f20286..13a0a988 100644
---- a/src/llama-arch.cpp
-+++ b/src/llama-arch.cpp
-@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
-     { LLM_ARCH_MINICPM3,         "minicpm3"         },
-     { LLM_ARCH_GEMMA,            "gemma"            },
-     { LLM_ARCH_GEMMA2,           "gemma2"           },
-+    { LLM_ARCH_GEMMA3,           "gemma3"           },
-     { LLM_ARCH_STARCODER2,       "starcoder2"       },
-     { LLM_ARCH_MAMBA,            "mamba"            },
-     { LLM_ARCH_XVERSE,           "xverse"           },
-@@ -64,6 +65,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
-     { LLM_ARCH_CHAMELEON,        "chameleon"        },
-     { LLM_ARCH_SOLAR,            "solar"            },
-     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
-+    { LLM_ARCH_MISTRAL3,         "mistral3"         },
-     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
- };
- 
-@@ -804,6 +806,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
-             { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
-         },
-     },
-+    {
-+        LLM_ARCH_GEMMA3,
-+        {
-+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
-+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
-+        },
-+    },
-     {
-         LLM_ARCH_STARCODER2,
-         {
-@@ -1352,6 +1372,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
-             { LLM_TENSOR_POS_NET_ATTN_OUT,  "posnet.%d.attn_output" },
-         },
-     },
-+    {
-+        LLM_ARCH_MISTRAL3,
-+        {
-+            { LLM_TENSOR_TOKEN_EMBD,  "token_embd" },
-+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-+            { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
-+            { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
-+            { LLM_TENSOR_ATTN_K,      "blk.%d.attn_k" },
-+            { LLM_TENSOR_ATTN_V,      "blk.%d.attn_v" },
-+            { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
-+            { LLM_TENSOR_FFN_NORM,    "blk.%d.ffn_norm" },
-+            { LLM_TENSOR_FFN_GATE,    "blk.%d.ffn_gate" },
-+            { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
-+            { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
-+        }
-+    },
-     {
-         LLM_ARCH_UNKNOWN,
-         {
-diff --git a/src/llama-arch.h b/src/llama-arch.h
-index ec742224..8476ae0a 100644
---- a/src/llama-arch.h
-+++ b/src/llama-arch.h
-@@ -41,6 +41,7 @@ enum llm_arch {
-     LLM_ARCH_MINICPM3,
-     LLM_ARCH_GEMMA,
-     LLM_ARCH_GEMMA2,
-+    LLM_ARCH_GEMMA3,
-     LLM_ARCH_STARCODER2,
-     LLM_ARCH_MAMBA,
-     LLM_ARCH_XVERSE,
-@@ -68,6 +69,7 @@ enum llm_arch {
-     LLM_ARCH_CHAMELEON,
-     LLM_ARCH_SOLAR,
-     LLM_ARCH_WAVTOKENIZER_DEC,
-+    LLM_ARCH_MISTRAL3,
-     LLM_ARCH_UNKNOWN,
- };
- 
-diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index ab1a07d1..db4f2685 100644
---- a/src/llama-model.cpp
-+++ b/src/llama-model.cpp
-@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
-                     default: type = LLM_TYPE_UNKNOWN;
-                }
-             } break;
-+        case LLM_ARCH_GEMMA3:
-+            {
-+            } break;
-         case LLM_ARCH_STARCODER2:
-             {
-                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-@@ -1274,6 +1277,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
-                 ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
-                 ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
-             } break;
-+        case LLM_ARCH_MISTRAL3: break;
-         default: throw std::runtime_error("unsupported model architecture");
-     }
- 
-@@ -2537,6 +2541,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
-                         layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
-                     }
-                 } break;
-+            case LLM_ARCH_GEMMA3:
-+                {
-+                } break;
-             case LLM_ARCH_STARCODER2:
-                 {
-                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-@@ -3531,6 +3538,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
-                     output   = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
-                     output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"),   {n_embd}, 0);
-                 } break;
-+            case LLM_ARCH_MISTRAL3: break;
-             default:
-                 throw std::runtime_error("unknown architecture");
-         }
-@@ -4009,6 +4017,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
-         case LLM_ARCH_GRANITE_MOE:
-         case LLM_ARCH_CHAMELEON:
-         case LLM_ARCH_SOLAR:
-+        case LLM_ARCH_MISTRAL3:
-             return LLAMA_ROPE_TYPE_NORM;
- 
-         // the pairs of head values are offset by n_rot/2
-@@ -4029,6 +4038,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
-         case LLM_ARCH_PHIMOE:
-         case LLM_ARCH_GEMMA:
-         case LLM_ARCH_GEMMA2:
-+        case LLM_ARCH_GEMMA3:
-         case LLM_ARCH_STARCODER2:
-         case LLM_ARCH_OPENELM:
-         case LLM_ARCH_GPTNEOX:
-diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
-index 6eb1da08..ebcbafa1 100644
---- a/src/llama-quant.cpp
-+++ b/src/llama-quant.cpp
-@@ -737,6 +737,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
-         // This used to be a regex, but  has an extreme cost to compile times.
-         bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
- 
-+        // don't quantize vision stuff
-+        quantize &= name.find("v.") == std::string::npos;
-+        quantize &= name.find("mm.") == std::string::npos;
-+
-         // quantize only 2D and 3D tensors (experts)
-         quantize &= (ggml_n_dims(tensor) >= 2);
- 
diff --git a/llama/patches/0021-remove-ggml-git-build-info.patch b/llama/patches/0021-remove-ggml-git-build-info.patch
new file mode 100644
index 000000000..04b7d65f2
--- /dev/null
+++ b/llama/patches/0021-remove-ggml-git-build-info.patch
@@ -0,0 +1,45 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca 
+Date: Sat, 12 Apr 2025 21:13:44 -0400
+Subject: [PATCH] remove ggml git build info
+
+---
+ ggml/CMakeLists.txt | 25 -------------------------
+ 1 file changed, 25 deletions(-)
+
+diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
+index d33f843b..a6c59f22 100644
+--- a/ggml/CMakeLists.txt
++++ b/ggml/CMakeLists.txt
+@@ -287,31 +287,6 @@ if (GGML_STANDALONE)
+         DESTINATION share/pkgconfig)
+ endif()
+ 
+-#
+-# Create CMake package
+-#
+-
+-# Generate version info based on git commit.
+-
+-if(NOT DEFINED GGML_BUILD_NUMBER)
+-    find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH)
+-    execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD
+-        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+-        OUTPUT_VARIABLE GGML_BUILD_NUMBER
+-        OUTPUT_STRIP_TRAILING_WHITESPACE
+-    )
+-
+-    if(GGML_BUILD_NUMBER EQUAL 1)
+-        message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.")
+-    endif()
+-
+-    execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
+-        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+-        OUTPUT_VARIABLE GGML_BUILD_COMMIT
+-        OUTPUT_STRIP_TRAILING_WHITESPACE
+-    )
+-endif()
+-
+ 
+ # Capture variables prefixed with GGML_.
+ 
diff --git a/llama/patches/0022-add-rdna4-support.patch b/llama/patches/0022-add-rdna4-support.patch
deleted file mode 100644
index 9b77ae84c..000000000
--- a/llama/patches/0022-add-rdna4-support.patch
+++ /dev/null
@@ -1,103 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Saman 
-Date: Wed, 19 Mar 2025 14:02:26 -0700
-Subject: [PATCH] add rdna4 support
-
----
- ggml/src/ggml-cuda/common.cuh    | 6 ++++--
- ggml/src/ggml-cuda/mmq.cu        | 2 +-
- ggml/src/ggml-cuda/mmq.cuh       | 4 ++--
- ggml/src/ggml-cuda/mmvq.cu       | 4 ++--
- ggml/src/ggml-cuda/vendors/hip.h | 4 ++++
- 5 files changed, 13 insertions(+), 7 deletions(-)
-
-diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
-index adf0d3ec..b24593fc 100644
---- a/ggml/src/ggml-cuda/common.cuh
-+++ b/ggml/src/ggml-cuda/common.cuh
-@@ -61,11 +61,13 @@
- #define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
- #define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
- #define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
-+#define GGML_CUDA_CC_RDNA4      (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
- 
- #define GGML_CUDA_CC_IS_RDNA(cc)  (cc >= GGML_CUDA_CC_RDNA1)
- #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
- #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
--#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
-+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
-+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
- #define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
- #define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
- 
-@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
- #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
-     c = __builtin_amdgcn_sdot4(a, b, c, false);
--#elif defined(RDNA3)
-+#elif defined(RDNA3) || defined(RDNA4)
-     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
- #elif defined(__gfx1010__) || defined(__gfx900__)
-     int tmp1;
-diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
-index 10f2ebb1..933d945c 100644
---- a/ggml/src/ggml-cuda/mmq.cu
-+++ b/ggml/src/ggml-cuda/mmq.cu
-@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
-         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
-     }
- 
--    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
-+    return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
- }
-diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
-index 0451c65f..66ce2bc9 100644
---- a/ggml/src/ggml-cuda/mmq.cuh
-+++ b/ggml/src/ggml-cuda/mmq.cuh
-@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
- 
- template <ggml_type type, int mmq_x, int nwarps, bool need_check>
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
--#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
-+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
-     __launch_bounds__(WARP_SIZE*nwarps, 2)
--#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
-+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
- #else
- #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-     __launch_bounds__(WARP_SIZE*nwarps, 1)
-diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
-index 4fb466ca..23ae7abc 100644
---- a/ggml/src/ggml-cuda/mmvq.cu
-+++ b/ggml/src/ggml-cuda/mmvq.cu
-@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
- 
-     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
- 
--#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
-+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
-     constexpr int nwarps              = 1;
-     constexpr int rows_per_cuda_block = 1;
- #else
-     constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
-     constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
--#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
- 
-     const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-     const     int row0 = rows_per_cuda_block*blockIdx.x;
-diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
-index 81964611..a62544b5 100644
---- a/ggml/src/ggml-cuda/vendors/hip.h
-+++ b/ggml/src/ggml-cuda/vendors/hip.h
-@@ -150,6 +150,10 @@
- #define CDNA
- #endif
- 
-+#if defined(__gfx1200__) || defined(__gfx1201__)
-+#define RDNA4
-+#endif
-+
- #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
-     defined(__gfx1150__) || defined(__gfx1151__)
- #define RDNA3
diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp
index b816cedd4..a87e23e5c 100644
--- a/llama/sampling_ext.cpp
+++ b/llama/sampling_ext.cpp
@@ -73,7 +73,7 @@ struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
     try {
         const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
         std::vector<std::string> splits = {};
-        llama_model_loader ml(std::string(fname), splits, false, false, nullptr);
+        llama_model_loader ml(std::string(fname), splits, false, false, nullptr, nullptr);
         vocab->load(ml, kv);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
diff --git a/ml/backend/ggml/ggml/.rsync-filter b/ml/backend/ggml/ggml/.rsync-filter
index ddad16e21..0a9d77afc 100644
--- a/ml/backend/ggml/ggml/.rsync-filter
+++ b/ml/backend/ggml/ggml/.rsync-filter
@@ -1,5 +1,6 @@
 protect *.go
 protect *-embed.*
+include cmake/
 include include/
 include src/
 include src/CMakeLists.txt
@@ -13,6 +14,7 @@ include src/ggml-cuda/vendors/
 include src/ggml-cuda/template-instances/
 include src/ggml-hip/
 include src/ggml-metal/
+include CMakeLists.txt
 include *.c
 include *.h
 include *.cpp
@@ -20,4 +22,6 @@ include *.cu
 include *.cuh
 include *.m
 include *.metal
+include common.cmake
+include ggml-config.cmake.in
 exclude *
diff --git a/ml/backend/ggml/ggml/CMakeLists.txt b/ml/backend/ggml/ggml/CMakeLists.txt
new file mode 100644
index 000000000..a6c59f229
--- /dev/null
+++ b/ml/backend/ggml/ggml/CMakeLists.txt
@@ -0,0 +1,337 @@
+cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
+project("ggml" C CXX)
+include(CheckIncludeFileCXX)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
+    set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
+    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
+endif()
+
+if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
+    set(GGML_STANDALONE ON)
+
+    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
+    # configure project version
+    # TODO
+else()
+    set(GGML_STANDALONE OFF)
+endif()
+
+if (EMSCRIPTEN)
+    set(BUILD_SHARED_LIBS_DEFAULT OFF)
+
+    option(GGML_WASM_SINGLE_FILE "ggml: embed WASM inside the generated ggml.js" ON)
+else()
+    if (MINGW)
+        set(BUILD_SHARED_LIBS_DEFAULT OFF)
+    else()
+        set(BUILD_SHARED_LIBS_DEFAULT ON)
+    endif()
+endif()
+
+# remove the lib prefix on win32 mingw
+if (WIN32)
+    set(CMAKE_STATIC_LIBRARY_PREFIX "")
+    set(CMAKE_SHARED_LIBRARY_PREFIX "")
+    set(CMAKE_SHARED_MODULE_PREFIX  "")
+endif()
+
+option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
+option(GGML_BACKEND_DL   "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF)
+
+#
+# option list
+#
+
+# TODO: mark all options as advanced when not GGML_STANDALONE
+
+if (APPLE)
+    set(GGML_METAL_DEFAULT ON)
+    set(GGML_BLAS_DEFAULT ON)
+    set(GGML_BLAS_VENDOR_DEFAULT "Apple")
+else()
+    set(GGML_METAL_DEFAULT OFF)
+    set(GGML_BLAS_DEFAULT OFF)
+    set(GGML_BLAS_VENDOR_DEFAULT "Generic")
+endif()
+
+if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH})
+    message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF")
+    set(GGML_NATIVE_DEFAULT OFF)
+else()
+    set(GGML_NATIVE_DEFAULT ON)
+endif()
+
+# defaults
+if (NOT GGML_LLAMAFILE_DEFAULT)
+    set(GGML_LLAMAFILE_DEFAULT OFF)
+endif()
+
+if (NOT GGML_CUDA_GRAPHS_DEFAULT)
+    set(GGML_CUDA_GRAPHS_DEFAULT OFF)
+endif()
+
+# general
+option(GGML_STATIC "ggml: static link libraries"                     OFF)
+option(GGML_NATIVE "ggml: optimize the build for the current system" ${GGML_NATIVE_DEFAULT})
+option(GGML_LTO    "ggml: enable link time optimization"             OFF)
+option(GGML_CCACHE "ggml: use ccache if available"                   ON)
+
+# debug
+option(GGML_ALL_WARNINGS           "ggml: enable all compiler warnings"                   ON)
+option(GGML_ALL_WARNINGS_3RD_PARTY "ggml: enable all compiler warnings in 3rd party libs" OFF)
+option(GGML_GPROF                  "ggml: enable gprof"                                   OFF)
+
+# build
+option(GGML_FATAL_WARNINGS    "ggml: enable -Werror flag"    OFF)
+
+# sanitizers
+option(GGML_SANITIZE_THREAD    "ggml: enable thread sanitizer"    OFF)
+option(GGML_SANITIZE_ADDRESS   "ggml: enable address sanitizer"   OFF)
+option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF)
+
+# instruction set specific
+if (GGML_NATIVE OR NOT GGML_NATIVE_DEFAULT)
+    set(INS_ENB OFF)
+else()
+    set(INS_ENB ON)
+endif()
+
+message(DEBUG "GGML_NATIVE         : ${GGML_NATIVE}")
+message(DEBUG "GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}")
+message(DEBUG "INS_ENB             : ${INS_ENB}")
+
+option(GGML_CPU_HBM          "ggml: use memkind for CPU HBM" OFF)
+option(GGML_CPU_AARCH64      "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
+option(GGML_CPU_KLEIDIAI     "ggml: use KleidiAI optimized kernels if applicable" OFF)
+option(GGML_AVX              "ggml: enable AVX"              ${INS_ENB})
+option(GGML_AVX_VNNI         "ggml: enable AVX-VNNI"         OFF)
+option(GGML_AVX2             "ggml: enable AVX2"             ${INS_ENB})
+option(GGML_BMI2             "ggml: enable BMI2"             ${INS_ENB})
+option(GGML_AVX512           "ggml: enable AVX512F"          OFF)
+option(GGML_AVX512_VBMI      "ggml: enable AVX512-VBMI"      OFF)
+option(GGML_AVX512_VNNI      "ggml: enable AVX512-VNNI"      OFF)
+option(GGML_AVX512_BF16      "ggml: enable AVX512-BF16"      OFF)
+if (NOT MSVC)
+    # in MSVC F16C and FMA is implied with AVX2/AVX512
+    option(GGML_FMA          "ggml: enable FMA"              ${INS_ENB})
+    option(GGML_F16C         "ggml: enable F16C"             ${INS_ENB})
+    # MSVC does not seem to support AMX
+    option(GGML_AMX_TILE     "ggml: enable AMX-TILE"         OFF)
+    option(GGML_AMX_INT8     "ggml: enable AMX-INT8"         OFF)
+    option(GGML_AMX_BF16     "ggml: enable AMX-BF16"         OFF)
+endif()
+option(GGML_LASX             "ggml: enable lasx"             ON)
+option(GGML_LSX              "ggml: enable lsx"              ON)
+option(GGML_RVV              "ggml: enable rvv"              ON)
+option(GGML_RV_ZFH           "ggml: enable riscv zfh"        OFF)
+option(GGML_VXE              "ggml: enable vxe"              ON)
+
+option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
+set(GGML_CPU_ARM_ARCH        "" CACHE STRING "ggml: CPU architecture for ARM")
+set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")
+
+
+if (WIN32)
+    set(GGML_WIN_VER "0x602" CACHE STRING   "ggml: Windows version")
+endif()
+
+# ggml core
+set(GGML_SCHED_MAX_COPIES  "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
+option(GGML_CPU                             "ggml: enable CPU backend"                        ON)
+
+# 3rd party libs / backends
+option(GGML_ACCELERATE                      "ggml: enable Accelerate framework"               ON)
+option(GGML_BLAS                            "ggml: use BLAS"                                  ${GGML_BLAS_DEFAULT})
+set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
+                                            "ggml: BLAS library vendor")
+option(GGML_LLAMAFILE                       "ggml: use LLAMAFILE"                             ${GGML_LLAMAFILE_DEFAULT})
+
+option(GGML_CUDA                            "ggml: use CUDA"                                  OFF)
+option(GGML_MUSA                            "ggml: use MUSA"                                  OFF)
+option(GGML_CUDA_FORCE_MMQ                  "ggml: use mmq kernels instead of cuBLAS"         OFF)
+option(GGML_CUDA_FORCE_CUBLAS               "ggml: always use cuBLAS instead of mmq kernels"  OFF)
+option(GGML_CUDA_F16                        "ggml: use 16 bit floats for some calculations"   OFF)
+set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
+                                            "ggml: max. batch size for using peer access")
+option(GGML_CUDA_NO_PEER_COPY               "ggml: do not use peer to peer copies"            OFF)
+option(GGML_CUDA_NO_VMM                     "ggml: do not try to use CUDA VMM"                OFF)
+option(GGML_CUDA_FA                         "ggml: compile ggml FlashAttention CUDA kernels"  ON)
+option(GGML_CUDA_FA_ALL_QUANTS              "ggml: compile all quants for FlashAttention"     OFF)
+option(GGML_CUDA_GRAPHS                     "ggml: use CUDA graphs (llama.cpp only)"          ${GGML_CUDA_GRAPHS_DEFAULT})
+set   (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING
+                                            "ggml: cuda link binary compression mode; requires cuda 12.8+")
+set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size")
+
+option(GGML_HIP                             "ggml: use HIP"                                   OFF)
+option(GGML_HIP_GRAPHS                      "ggml: use HIP graph, experimental, slow"         OFF)
+option(GGML_HIP_NO_VMM                      "ggml: do not try to use HIP VMM"                 ON)
+option(GGML_HIP_ROCWMMA_FATTN               "ggml: enable rocWMMA for FlashAttention"         OFF)
+option(GGML_HIP_UMA                         "ggml: use HIP unified memory architecture"       OFF)
+option(GGML_VULKAN                          "ggml: use Vulkan"                                OFF)
+option(GGML_VULKAN_CHECK_RESULTS            "ggml: run Vulkan op checks"                      OFF)
+option(GGML_VULKAN_DEBUG                    "ggml: enable Vulkan debug output"                OFF)
+option(GGML_VULKAN_MEMORY_DEBUG             "ggml: enable Vulkan memory debug output"         OFF)
+option(GGML_VULKAN_SHADER_DEBUG_INFO        "ggml: enable Vulkan shader debug info"           OFF)
+option(GGML_VULKAN_PERF                     "ggml: enable Vulkan perf output"                 OFF)
+option(GGML_VULKAN_VALIDATE                 "ggml: enable Vulkan validation"                  OFF)
+option(GGML_VULKAN_RUN_TESTS                "ggml: run Vulkan tests"                          OFF)
+option(GGML_KOMPUTE                         "ggml: use Kompute"                               OFF)
+option(GGML_METAL                           "ggml: use Metal"                                 ${GGML_METAL_DEFAULT})
+option(GGML_METAL_USE_BF16                  "ggml: use bfloat if available"                   OFF)
+option(GGML_METAL_NDEBUG                    "ggml: disable Metal debugging"                   OFF)
+option(GGML_METAL_SHADER_DEBUG              "ggml: compile Metal with -fno-fast-math"         OFF)
+option(GGML_METAL_EMBED_LIBRARY             "ggml: embed Metal library"                       ${GGML_METAL})
+set   (GGML_METAL_MACOSX_VERSION_MIN "" CACHE STRING
+                                            "ggml: metal minimum macOS version")
+set   (GGML_METAL_STD "" CACHE STRING       "ggml: metal standard version (-std flag)")
+option(GGML_OPENMP                          "ggml: use OpenMP"                                ON)
+option(GGML_RPC                             "ggml: use RPC"                                   OFF)
+option(GGML_SYCL                            "ggml: use SYCL"                                  OFF)
+option(GGML_SYCL_F16                        "ggml: use 16 bit floats for sycl calculations"   OFF)
+option(GGML_SYCL_GRAPH                      "ggml: enable graphs in the SYCL backend"         ON)
+set   (GGML_SYCL_TARGET "INTEL" CACHE STRING
+                                            "ggml: sycl target device")
+set   (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
+                                            "ggml: sycl device architecture")
+
+option(GGML_OPENCL                          "ggml: use OpenCL"                                OFF)
+option(GGML_OPENCL_PROFILING                "ggml: use OpenCL profiling (increases overhead)" OFF)
+option(GGML_OPENCL_EMBED_KERNELS            "ggml: embed kernels"                             ON)
+option(GGML_OPENCL_USE_ADRENO_KERNELS       "ggml: use optimized kernels for Adreno"          ON)
+set   (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
+                                            "gmml: OpenCL API version to target")
+
+# toolchain for vulkan-shaders-gen
+set   (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
+
+# extra artifacts
+option(GGML_BUILD_TESTS    "ggml: build tests"    ${GGML_STANDALONE})
+option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
+
+#
+# dependencies
+#
+
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_C_STANDARD_REQUIRED true)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED true)
+
+set(THREADS_PREFER_PTHREAD_FLAG ON)
+
+find_package(Threads REQUIRED)
+
+include(GNUInstallDirs)
+
+#
+# build the library
+#
+
+add_subdirectory(src)
+
+#
+# tests and examples
+#
+
+if (GGML_BUILD_TESTS)
+    enable_testing()
+    add_subdirectory(tests)
+endif ()
+
+if (GGML_BUILD_EXAMPLES)
+    add_subdirectory(examples)
+endif ()
+
+#
+# install
+#
+
+include(CMakePackageConfigHelpers)
+
+# all public headers
+set(GGML_PUBLIC_HEADERS
+    include/ggml.h
+    include/ggml-cpu.h
+    include/ggml-alloc.h
+    include/ggml-backend.h
+    include/ggml-blas.h
+    include/ggml-cann.h
+    include/ggml-cpp.h
+    include/ggml-cuda.h
+    include/ggml-kompute.h
+    include/ggml-opt.h
+    include/ggml-metal.h
+    include/ggml-rpc.h
+    include/ggml-sycl.h
+    include/ggml-vulkan.h
+    include/gguf.h)
+
+set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
+#if (GGML_METAL)
+#    set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
+#endif()
+install(TARGETS ggml LIBRARY PUBLIC_HEADER)
+install(TARGETS ggml-base LIBRARY)
+
+if (GGML_STANDALONE)
+    configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in
+        ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
+        @ONLY)
+
+    install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
+        DESTINATION share/pkgconfig)
+endif()
+
+
+# Capture variables prefixed with GGML_.
+
+set(variable_set_statements
+"
+####### Expanded from @GGML_VARIABLES_EXPANDED@ by configure_package_config_file() #######
+####### Any changes to this file will be overwritten by the next CMake run        #######
+
+")
+
+set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS})
+
+get_cmake_property(all_variables VARIABLES)
+foreach(variable_name IN LISTS all_variables)
+    if(variable_name MATCHES "^GGML_")
+        string(REPLACE ";" "\\;"
+               variable_value "${${variable_name}}")
+
+        set(variable_set_statements
+            "${variable_set_statements}set(${variable_name} \"${variable_value}\")\n")
+    endif()
+endforeach()
+
+set(GGML_VARIABLES_EXPANDED ${variable_set_statements})
+
+# Create the CMake package and set install location.
+
+set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER})
+set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header  files")
+set(GGML_LIB_INSTALL_DIR     ${CMAKE_INSTALL_LIBDIR}     CACHE PATH "Location of library files")
+set(GGML_BIN_INSTALL_DIR     ${CMAKE_INSTALL_BINDIR}     CACHE PATH "Location of binary  files")
+
+configure_package_config_file(
+        ${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in
+        ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
+    INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml
+    PATH_VARS GGML_INCLUDE_INSTALL_DIR
+              GGML_LIB_INSTALL_DIR
+              GGML_BIN_INSTALL_DIR)
+
+write_basic_package_version_file(
+        ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
+    VERSION ${GGML_INSTALL_VERSION}
+    COMPATIBILITY SameMajorVersion)
+
+install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
+              ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
+        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
diff --git a/ml/backend/ggml/ggml/cmake/common.cmake b/ml/backend/ggml/ggml/cmake/common.cmake
new file mode 100644
index 000000000..1976d0ae9
--- /dev/null
+++ b/ml/backend/ggml/ggml/cmake/common.cmake
@@ -0,0 +1,26 @@
+function(ggml_get_flags CCID CCVER)
+    set(C_FLAGS "")
+    set(CXX_FLAGS "")
+
+    if (CCID MATCHES "Clang")
+        set(C_FLAGS   -Wunreachable-code-break -Wunreachable-code-return)
+        set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
+
+        if (
+            (CCID STREQUAL "Clang"      AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
+            (CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
+        )
+            list(APPEND C_FLAGS -Wdouble-promotion)
+        endif()
+    elseif (CCID STREQUAL "GNU")
+        set(C_FLAGS   -Wdouble-promotion)
+        set(CXX_FLAGS -Wno-array-bounds)
+
+        if (CCVER VERSION_GREATER_EQUAL 8.1.0)
+            list(APPEND CXX_FLAGS -Wextra-semi)
+        endif()
+    endif()
+
+    set(GF_C_FLAGS   ${C_FLAGS}   PARENT_SCOPE)
+    set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
+endfunction()
diff --git a/ml/backend/ggml/ggml/cmake/ggml-config.cmake.in b/ml/backend/ggml/ggml/cmake/ggml-config.cmake.in
new file mode 100644
index 000000000..8c2dc31c6
--- /dev/null
+++ b/ml/backend/ggml/ggml/cmake/ggml-config.cmake.in
@@ -0,0 +1,152 @@
+
+@GGML_VARIABLES_EXPANDED@
+
+@PACKAGE_INIT@
+
+set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
+set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
+#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
+
+find_package(Threads REQUIRED)
+
+find_library(GGML_LIBRARY ggml
+    REQUIRED
+    HINTS ${GGML_LIB_DIR}
+    NO_CMAKE_FIND_ROOT_PATH)
+
+add_library(ggml::ggml UNKNOWN IMPORTED)
+set_target_properties(ggml::ggml
+    PROPERTIES
+        IMPORTED_LOCATION "${GGML_LIBRARY}")
+
+find_library(GGML_BASE_LIBRARY ggml-base
+    REQUIRED
+    HINTS ${GGML_LIB_DIR}
+    NO_CMAKE_FIND_ROOT_PATH)
+
+add_library(ggml::ggml-base UNKNOWN IMPORTED)
+set_target_properties(ggml::ggml-base
+    PROPERTIES
+        IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
+
+if (NOT GGML_SHARED_LIB)
+    if (APPLE AND GGML_ACCELERATE)
+        find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
+        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
+    endif()
+
+    if (GGML_OPENMP)
+        find_package(OpenMP REQUIRED)
+        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
+    endif()
+
+    if (GGML_CPU_HBM)
+        find_library(memkind memkind REQUIRED)
+        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
+    endif()
+
+    if (GGML_BLAS)
+        find_package(BLAS REQUIRED)
+        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
+        list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS   ${BLAS_LINKER_FLAGS})
+    endif()
+
+    if (GGML_CUDA)
+        find_package(CUDAToolkit REQUIRED)
+    endif()
+
+    if (GGML_METAL)
+        find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
+        find_library(METAL_FRAMEWORK    Metal REQUIRED)
+        find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
+
+        list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES
+                    ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
+    endif()
+
+    if (GGML_VULKAN)
+        find_package(Vulkan REQUIRED)
+        list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan)
+    endif()
+
+    if (GGML_HIP)
+        find_package(hip     REQUIRED)
+        find_package(hipblas REQUIRED)
+        find_package(rocblas REQUIRED)
+        list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
+    endif()
+
+    if (GGML_SYCL)
+        find_package(DNNL)
+        if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
+            list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
+        endif()
+        if (WIN32)
+            find_package(IntelSYCL REQUIRED)
+            find_package(MKL       REQUIRED)
+            list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
+        endif()
+    endif()
+endif()
+
+set(_ggml_all_targets "")
+foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
+    string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
+    string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
+
+    find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
+        REQUIRED
+        HINTS ${GGML_LIB_DIR}
+        NO_CMAKE_FIND_ROOT_PATH)
+
+    message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
+
+    add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
+    set_target_properties(ggml::${_ggml_backend}
+        PROPERTIES
+            INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
+            IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
+            IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
+            INTERFACE_COMPILE_FEATURES c_std_90
+            POSITION_INDEPENDENT_CODE ON)
+
+    string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
+    if(is_cpu_variant)
+        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
+        set_target_properties(ggml::${_ggml_backend}
+           PROPERTIES
+               INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
+
+        if(GGML_CPU_INTERFACE_LINK_OPTIONS)
+            set_target_properties(ggml::${_ggml_backend}
+                PROPERTIES
+                    INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
+        endif()
+
+    else()
+        list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
+        set_target_properties(ggml::${_ggml_backend}
+            PROPERTIES
+                INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
+
+        if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
+            set_target_properties(ggml::${_ggml_backend}
+                PROPERTIES
+                    INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
+        endif()
+    endif()
+
+    list(APPEND _ggml_all_targets ggml::${_ggml_backend})
+endforeach()
+
+list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}")
+set_target_properties(ggml::ggml
+    PROPERTIES
+        INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}")
+
+add_library(ggml::all INTERFACE IMPORTED)
+set_target_properties(ggml::all
+    PROPERTIES
+        INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")
+
+check_required_components(ggml)
diff --git a/ml/backend/ggml/ggml/include/ggml-alloc.h b/ml/backend/ggml/ggml/include/ggml-alloc.h
index 23600eea9..2cb150fd2 100644
--- a/ml/backend/ggml/ggml/include/ggml-alloc.h
+++ b/ml/backend/ggml/ggml/include/ggml-alloc.h
@@ -19,7 +19,7 @@ struct ggml_tallocr {
 };
 
 GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
-GGML_API void                ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
+GGML_API enum ggml_status    ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
 
 // Graph allocator
 /*
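// Illustrative sketch (not part of the upstream patch): with ggml_tallocr_alloc
// now returning a status, callers can report allocation failures instead of
// assuming success. `buffer` and `t` are assumed to be created elsewhere.
#include <stdbool.h>
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

static bool try_tallocr_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * t) {
    struct ggml_tallocr talloc = ggml_tallocr_new(buffer);
    return ggml_tallocr_alloc(&talloc, t) == GGML_STATUS_SUCCESS;
}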
diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h
index fc9571c82..64671495b 100644
--- a/ml/backend/ggml/ggml/include/ggml-backend.h
+++ b/ml/backend/ggml/ggml/include/ggml-backend.h
@@ -56,7 +56,7 @@ extern "C" {
     GGML_API void                           ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
     GGML_API void *                         ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
     GGML_API size_t                         ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
-    GGML_API void                           ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API enum ggml_status               ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
     GGML_API size_t                         ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
     GGML_API size_t                         ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
     GGML_API size_t                         ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
@@ -342,8 +342,8 @@ extern "C" {
     GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
 
     // Tensor initialization
-    GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
-    GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
+    GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
+    GGML_API enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor);
 
     // CPU buffer types are always available
     GGML_API ggml_backend_buffer_t      ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
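// Illustrative sketch (not part of the upstream patch): ggml_backend_view_init
// and ggml_backend_tensor_alloc now report errors through ggml_status, so view
// setup can be checked. `view` is assumed to have view_src and view_offs set.
#include <stdbool.h>
#include "ggml.h"
#include "ggml-backend.h"

static bool init_view_checked(struct ggml_tensor * view) {
    return ggml_backend_view_init(view) == GGML_STATUS_SUCCESS;
}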
diff --git a/ml/backend/ggml/ggml/include/ggml-cpu.h b/ml/backend/ggml/ggml/include/ggml-cpu.h
index b48cc560e..f5e11f1e1 100644
--- a/ml/backend/ggml/ggml/include/ggml-cpu.h
+++ b/ml/backend/ggml/ggml/include/ggml-cpu.h
@@ -80,6 +80,7 @@ extern "C" {
     GGML_BACKEND_API int ggml_cpu_has_avx        (void);
     GGML_BACKEND_API int ggml_cpu_has_avx_vnni   (void);
     GGML_BACKEND_API int ggml_cpu_has_avx2       (void);
+    GGML_BACKEND_API int ggml_cpu_has_bmi2       (void);
     GGML_BACKEND_API int ggml_cpu_has_f16c       (void);
     GGML_BACKEND_API int ggml_cpu_has_fma        (void);
     GGML_BACKEND_API int ggml_cpu_has_avx512     (void);
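// Illustrative sketch (not part of the upstream patch): the new BMI2 query is
// used like the other ggml_cpu_has_* feature checks.
#include <stdio.h>
#include "ggml-cpu.h"

static void print_x86_features(void) {
    printf("AVX2: %d, BMI2: %d\n", ggml_cpu_has_avx2(), ggml_cpu_has_bmi2());
}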
diff --git a/ml/backend/ggml/ggml/include/ggml-rpc.h b/ml/backend/ggml/ggml/include/ggml-rpc.h
index ade6c3b0e..4e0d210f8 100644
--- a/ml/backend/ggml/ggml/include/ggml-rpc.h
+++ b/ml/backend/ggml/ggml/include/ggml-rpc.h
@@ -17,7 +17,9 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
 
 GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
 
-GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
+                                                    const char * cache_dir,
+                                                    size_t free_mem, size_t total_mem);
 
 GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
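// Illustrative sketch (not part of the upstream patch): starting an RPC server
// with the new cache directory argument. The endpoint and cache path below are
// hypothetical values, and `backend` is assumed to be initialized by the caller.
#include "ggml-backend.h"
#include "ggml-rpc.h"

static void start_rpc_server(ggml_backend_t backend, size_t free_mem, size_t total_mem) {
    ggml_backend_rpc_start_server(backend, "0.0.0.0:50052", "/tmp/ggml-rpc-cache",
                                  free_mem, total_mem);
}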
 
diff --git a/ml/backend/ggml/ggml/include/ggml.h b/ml/backend/ggml/ggml/include/ggml.h
index 8d269a9cc..d19fc1678 100644
--- a/ml/backend/ggml/ggml/include/ggml.h
+++ b/ml/backend/ggml/ggml/include/ggml.h
@@ -454,6 +454,7 @@ extern "C" {
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
+        GGML_OP_L2_NORM,
 
         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT_ID,
@@ -503,20 +504,16 @@ extern "C" {
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
+        GGML_OP_RWKV_WKV7,
 
         GGML_OP_UNARY,
 
-        GGML_OP_MAP_UNARY,
-        GGML_OP_MAP_BINARY,
-
-        GGML_OP_MAP_CUSTOM1_F32,
-        GGML_OP_MAP_CUSTOM2_F32,
-        GGML_OP_MAP_CUSTOM3_F32,
-
         GGML_OP_MAP_CUSTOM1,
         GGML_OP_MAP_CUSTOM2,
         GGML_OP_MAP_CUSTOM3,
 
+        GGML_OP_CUSTOM,
+
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
@@ -1096,6 +1093,18 @@ extern "C" {
             int                   n_groups,
             float                 eps);
 
+    // l2 normalize along rows
+    // used in rwkv v7
+    GGML_API struct ggml_tensor * ggml_l2_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
     // a - x
     // b - dy
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
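// Illustrative sketch (not part of the upstream patch): using the new row-wise
// L2 normalization declared above. `ctx` and `a` are assumed to exist; the
// epsilon is an arbitrary small constant.
#include "ggml.h"

static struct ggml_tensor * l2_normalize_rows(struct ggml_context * ctx, struct ggml_tensor * a) {
    return ggml_l2_norm(ctx, a, 1e-6f);
}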
@@ -1709,24 +1718,29 @@ extern "C" {
             float                 p0,
             float                 p1);
 
-    // nearest interpolate
+    enum ggml_scale_mode {
+        GGML_SCALE_MODE_NEAREST  = 0,
+        GGML_SCALE_MODE_BILINEAR = 1,
+    };
+
+    // interpolate
     // multiplies ne0 and ne1 by scale factor
-    // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_upscale(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   scale_factor);
+            int                   scale_factor,
+            enum ggml_scale_mode  mode);
 
-    // nearest interpolate
-    // nearest interpolate to specified dimensions
-    // used in tortoise.cpp
+    // interpolate
+    // interpolate scale to specified dimensions
     GGML_API struct ggml_tensor * ggml_upscale_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   ne0,
             int                   ne1,
             int                   ne2,
-            int                   ne3);
+            int                   ne3,
+            enum ggml_scale_mode  mode);
 
     // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
     GGML_API struct ggml_tensor * ggml_pad(
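// Illustrative sketch (not part of the upstream patch): ggml_upscale now takes
// an explicit scale mode; GGML_SCALE_MODE_NEAREST keeps the previous behaviour,
// while bilinear interpolation is selected as below. `ctx` and `a` are assumed.
#include "ggml.h"

static struct ggml_tensor * upscale_2x_bilinear(struct ggml_context * ctx, struct ggml_tensor * a) {
    return ggml_upscale(ctx, a, 2, GGML_SCALE_MODE_BILINEAR);
}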
@@ -1787,11 +1801,11 @@ extern "C" {
 
 #define GGML_KQ_MASK_PAD 64
 
-    // q:    [n_embd, n_batch,     n_head,    1]
-    // k:    [n_embd, n_kv,        n_head_kv, 1]
-    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
-    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    1]
+    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
+    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   1] !! permuted !!
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
@@ -1900,85 +1914,18 @@ extern "C" {
             struct ggml_tensor  * state,
             float scale);
 
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * w,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * state);
+
     // custom operators
 
-    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
-    typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
-
-    typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
-            struct ggml_context        * ctx,
-            struct ggml_tensor         * a,
-                   ggml_unary_op_f32_t   fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
-            struct ggml_context        * ctx,
-            struct ggml_tensor         * a,
-                   ggml_unary_op_f32_t   fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b,
-                   ggml_binary_op_f32_t   fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b,
-                   ggml_binary_op_f32_t   fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-                   ggml_custom1_op_f32_t   fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-                   ggml_custom1_op_f32_t   fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-                   ggml_custom2_op_f32_t   fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-                   ggml_custom2_op_f32_t   fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-            struct ggml_tensor           * c,
-                   ggml_custom3_op_f32_t   fun),
-        "use ggml_map_custom3 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-            struct ggml_tensor           * c,
-                   ggml_custom3_op_f32_t   fun),
-        "use ggml_map_custom3_inplace instead");
-
-    // custom operators v2
-
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
     typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
     typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
@@ -2034,6 +1981,30 @@ extern "C" {
             int                     n_tasks,
             void                  * userdata);
 
+    typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);
+
+    GGML_API struct ggml_tensor * ggml_custom_4d(
+            struct ggml_context * ctx,
+            enum ggml_type        type,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3,
+            struct ggml_tensor ** args,
+            int                   n_args,
+            ggml_custom_op_t      fun,
+            int                   n_tasks,
+            void                * userdata);
+
+    GGML_API struct ggml_tensor * ggml_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor ** args,
+            int                   n_args,
+            ggml_custom_op_t      fun,
+            int                   n_tasks,
+            void                * userdata);
+
     // loss function
 
     GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
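// Illustrative sketch (not part of the upstream patch): a minimal custom op
// built with the new ggml_custom_4d API. The 16x16 F32 shape, the empty args
// list and the single task are arbitrary choices for the example.
#include <stddef.h>
#include <stdint.h>
#include "ggml.h"

static void fill_ones(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
    (void) userdata;
    float * data = (float *) dst->data;
    const int64_t n = ggml_nelements(dst);
    // naive interleaved split of the work across nth threads
    for (int64_t i = ith; i < n; i += nth) {
        data[i] = 1.0f;
    }
}

static struct ggml_tensor * build_custom(struct ggml_context * ctx) {
    return ggml_custom_4d(ctx, GGML_TYPE_F32, 16, 16, 1, 1,
                          /*args   =*/ NULL, /*n_args =*/ 0,
                          fill_ones, /*n_tasks=*/ 1, /*userdata=*/ NULL);
}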
@@ -2150,7 +2121,11 @@ extern "C" {
 #        define GGML_RESTRICT
 #    endif
 #else
-#    define GGML_RESTRICT restrict
+#    if defined (_MSC_VER) && (__STDC_VERSION__ < 201112L)
+#        define GGML_RESTRICT __restrict
+#    else
+#        define GGML_RESTRICT restrict
+#    endif
 #endif
     typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
     typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int64_t k);
diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt
index 4564df91a..d6b393a21 100644
--- a/ml/backend/ggml/ggml/src/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/CMakeLists.txt
@@ -1,4 +1,5 @@
 include(CheckCXXCompilerFlag)
+include("../cmake/common.cmake")
 
 add_compile_definitions(GGML_SCHED_MAX_COPIES=${GGML_SCHED_MAX_COPIES})
 
@@ -24,33 +25,6 @@ if (NOT MSVC)
     endif()
 endif()
 
-function(ggml_get_flags CCID CCVER)
-    set(C_FLAGS "")
-    set(CXX_FLAGS "")
-
-    if (CCID MATCHES "Clang")
-        set(C_FLAGS   -Wunreachable-code-break -Wunreachable-code-return)
-        set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
-
-        if (
-            (CCID STREQUAL "Clang"      AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
-            (CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
-        )
-            list(APPEND C_FLAGS -Wdouble-promotion)
-        endif()
-    elseif (CCID STREQUAL "GNU")
-        set(C_FLAGS   -Wdouble-promotion)
-        set(CXX_FLAGS -Wno-array-bounds)
-
-        if (CCVER VERSION_GREATER_EQUAL 8.1.0)
-            list(APPEND CXX_FLAGS -Wextra-semi)
-        endif()
-    endif()
-
-    set(GF_C_FLAGS   ${C_FLAGS}   PARENT_SCOPE)
-    set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
-endfunction()
-
 if (GGML_FATAL_WARNINGS)
     if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
         list(APPEND C_FLAGS   -Werror)
@@ -91,7 +65,7 @@ if (GGML_LTO)
     endif()
 endif()
 
-if (GGML_CCACHE)
+if (GGML_CCACHE AND NOT CMAKE_C_COMPILER_LAUNCHER AND NOT CMAKE_CXX_COMPILER_LAUNCHER)
     find_program(GGML_CCACHE_FOUND ccache)
     find_program(GGML_SCCACHE_FOUND sccache)
 
@@ -102,7 +76,11 @@ if (GGML_CCACHE)
             set(GGML_CCACHE_VARIANT sccache)
         endif()
         # TODO: should not be set globally
-        set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
+        if (GGML_SYCL AND GGML_CCACHE_FOUND AND WIN32)
+            set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache compiler_type=icl")
+        else ()
+            set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
+        endif ()
         set(ENV{CCACHE_SLOPPINESS} time_macros)
         message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
     else()
@@ -226,6 +204,9 @@ add_library(ggml-base
             gguf.cpp)
 
 target_include_directories(ggml-base PRIVATE .)
+if (GGML_BACKEND_DL)
+    target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
+endif()
 
 add_library(ggml
             ggml-backend-reg.cpp)
@@ -233,7 +214,7 @@ add_library(ggml
 target_link_libraries(ggml PUBLIC ggml-base)
 
 if (CMAKE_SYSTEM_NAME MATCHES "Linux")
-    target_link_libraries(ggml PRIVATE dl)
+    target_link_libraries(ggml PRIVATE dl stdc++fs)
 endif()
 
 function(ggml_add_backend_library backend)
@@ -286,7 +267,7 @@ function(ggml_add_cpu_backend_variant tag_name)
     set(GGML_CPU_TAG_NAME ${tag_name})
     # other: OPENMP LLAMAFILE CPU_HBM
     foreach (feat NATIVE
-                  AVX AVX2 AVX_VNNI FMA F16C
+                  AVX AVX2 BMI2 AVX_VNNI FMA F16C
                   AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16
                   AMX_TILE AMX_INT8 AMX_BF16)
         set(GGML_${feat} OFF)
@@ -308,10 +289,10 @@ if (GGML_CPU_ALL_VARIANTS)
     endif()
     add_custom_target(ggml-cpu)
     ggml_add_cpu_backend_variant(sandybridge    AVX)
-    ggml_add_cpu_backend_variant(haswell        AVX F16C AVX2 FMA)
-    ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 FMA AVX512)
-    ggml_add_cpu_backend_variant(icelake        AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
-    ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 FMA AVX_VNNI)
+    ggml_add_cpu_backend_variant(haswell        AVX F16C AVX2 BMI2 FMA)
+    ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 BMI2 FMA AVX512)
+    ggml_add_cpu_backend_variant(icelake        AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
+    ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 BMI2 FMA AVX_VNNI)
 elseif (GGML_CPU)
     ggml_add_cpu_backend_variant_impl("")
 endif()
@@ -346,6 +327,10 @@ if (CMAKE_SYSTEM_NAME MATCHES "Android")
     target_link_libraries(ggml-base PRIVATE dl)
 endif()
 
+if(CMAKE_SYSTEM_NAME MATCHES "visionOS")
+    target_compile_definitions(ggml-base PUBLIC _DARWIN_C_SOURCE)
+endif()
+
 if (BUILD_SHARED_LIBS)
     foreach (target ggml-base ggml)
         set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON)
diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c
index 7244a9cbb..a3d3f6901 100644
--- a/ml/backend/ggml/ggml/src/ggml-alloc.c
+++ b/ml/backend/ggml/ggml/src/ggml-alloc.c
@@ -89,7 +89,7 @@ struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer) {
     return talloc;
 }
 
-void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) {
+enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) {
     size_t size = ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor);
     size = GGML_PAD(size, talloc->alignment);
 
@@ -104,7 +104,7 @@ void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tenso
 
     assert(((uintptr_t)addr % talloc->alignment) == 0);
 
-    ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);
+    return ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);
 }
 
 // dynamic tensor allocator
@@ -933,42 +933,51 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
 
 // utils
 
+static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
+    for (size_t i = 0; i < *n_buffers; i++) {
+        ggml_backend_buffer_free((*buffers)[i]);
+    }
+    free(*buffers);
+}
+
 static bool alloc_tensor_range(struct ggml_context * ctx,
         struct ggml_tensor * first, struct ggml_tensor * last,
         ggml_backend_buffer_type_t buft, size_t size,
         ggml_backend_buffer_t ** buffers, size_t * n_buffers) {
+
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
     if (buffer == NULL) {
-#ifndef NDEBUG
-        GGML_LOG_DEBUG("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
-#endif
-        for (size_t i = 0; i < *n_buffers; i++) {
-            ggml_backend_buffer_free((*buffers)[i]);
-        }
-        free(*buffers);
+        GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
+        free_buffers(buffers, n_buffers);
         return false;
     }
 
-    struct ggml_tallocr tallocr = ggml_tallocr_new(buffer);
-
-    for (struct ggml_tensor * t = first; t != last; t = ggml_get_next_tensor(ctx, t)) {
-        if (t->data == NULL) {
-            if (t->view_src == NULL) {
-                ggml_tallocr_alloc(&tallocr, t);
-            } else if (t->buffer == NULL) {
-                ggml_backend_view_init(t);
-            }
-        } else {
-            if (t->view_src != NULL && t->buffer == NULL) {
-                // view of a pre-allocated tensor
-                ggml_backend_view_init(t);
-            }
-        }
-    }
-
     *buffers = realloc(*buffers, sizeof(ggml_backend_buffer_t) * (*n_buffers + 1));
     (*buffers)[(*n_buffers)++] = buffer;
 
+    struct ggml_tallocr tallocr = ggml_tallocr_new(buffer);
+
+    for (struct ggml_tensor * t = first; t != last; t = ggml_get_next_tensor(ctx, t)) {
+        enum ggml_status status = GGML_STATUS_SUCCESS;
+        if (t->data == NULL) {
+            if (t->view_src == NULL) {
+                status = ggml_tallocr_alloc(&tallocr, t);
+            } else if (t->buffer == NULL) {
+                status = ggml_backend_view_init(t);
+            }
+        } else {
+            if (t->view_src != NULL && t->buffer == NULL) {
+                // view of a pre-allocated tensor
+                status = ggml_backend_view_init(t);
+            }
+        }
+        if (status != GGML_STATUS_SUCCESS) {
+            GGML_LOG_ERROR("%s: failed to initialize tensor %s\n", __func__, t->name);
+            free_buffers(buffers, n_buffers);
+            return false;
+        }
+    }
+
     return true;
 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-impl.h b/ml/backend/ggml/ggml/src/ggml-backend-impl.h
index d1c2d76d8..c36c12d65 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-backend-impl.h
@@ -44,7 +44,7 @@ extern "C" {
         // base address of the buffer
         void *       (*get_base)     (ggml_backend_buffer_t buffer);
         // (optional) initialize a tensor in the buffer (eg. add tensor extras)
-        void         (*init_tensor)  (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        enum ggml_status (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
         // tensor data access
         void         (*memset_tensor)(ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
         void         (*set_tensor)   (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
index 799af5f3a..1487f322f 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
@@ -2,14 +2,13 @@
 #include "ggml-backend.h"
 #include "ggml-impl.h"
 #include 
-#include 
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 
 #include 
+#include 
 
 #ifdef _WIN32
 #    define WIN32_LEAN_AND_MEAN
@@ -66,6 +65,34 @@
 #include "ggml-kompute.h"
 #endif
 
+// disable C++17 deprecation warning for std::codecvt_utf8
+#if defined(__clang__)
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
+namespace fs = std::filesystem;
+
+static std::string path_str(const fs::path & path) {
+    std::string u8path;
+    try {
+#if defined(__cpp_lib_char8_t)
+        // C++20 and later: u8string() returns std::u8string
+        std::u8string u8str = path.u8string();
+        u8path = std::string(reinterpret_cast<const char *>(u8str.c_str()));
+#else
+        // C++17: u8string() returns std::string
+        u8path = path.u8string();
+#endif
+    } catch (...) {
+    }
+    return u8path;
+}
+
+#if defined(__clang__)
+#    pragma clang diagnostic pop
+#endif
+
 #ifdef _WIN32
 
 using dl_handle = std::remove_pointer_t<HMODULE>;
@@ -76,12 +103,12 @@ struct dl_handle_deleter {
     }
 };
 
-static dl_handle * dl_load_library(const std::filesystem::path & path) {
+static dl_handle * dl_load_library(const fs::path & path) {
     // suppress error dialogs for missing DLLs
     DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
     SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
 
-    HMODULE handle = LoadLibraryW(path.c_str());
+    HMODULE handle = LoadLibraryW(path.wstring().c_str());
 
     SetErrorMode(old_mode);
 
@@ -109,8 +136,8 @@ struct dl_handle_deleter {
     }
 };
 
-static void * dl_load_library(const std::filesystem::path & path) {
-    dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
+static void * dl_load_library(const fs::path & path) {
+    dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
 
     return handle;
 }
@@ -121,25 +148,6 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
 
 #endif
 
-static std::string path_to_string(const std::filesystem::path & path)
-{
-#ifdef _WIN32
-    const std::wstring wstr = path.wstring();
-    const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr);
-    if (size_needed <= 0) {
-        return std::string();
-    }
-
-    // size_needed includes the null terminator
-    std::string str(size_needed - 1, '\0');
-    WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr);
-    return str;
-#else
-    return path.string();
-#endif
-}
-
-
 using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
 
 struct ggml_backend_reg_entry {
@@ -221,11 +229,11 @@ struct ggml_backend_registry {
         );
     }
 
-    ggml_backend_reg_t load_backend(const std::filesystem::path & path, bool silent) {
+    ggml_backend_reg_t load_backend(const fs::path & path, bool silent) {
         dl_handle_ptr handle { dl_load_library(path) };
         if (!handle) {
             if (!silent) {
-                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str());
+                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str());
             }
             return nullptr;
         }
@@ -233,7 +241,7 @@ struct ggml_backend_registry {
         auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
         if (score_fn && score_fn() == 0) {
             if (!silent) {
-                GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str());
+                GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_str(path).c_str());
             }
             return nullptr;
         }
@@ -241,7 +249,7 @@ struct ggml_backend_registry {
         auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
         if (!backend_init_fn) {
             if (!silent) {
-                GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str());
+                GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_str(path).c_str());
             }
             return nullptr;
         }
@@ -250,16 +258,17 @@ struct ggml_backend_registry {
         if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
             if (!silent) {
                 if (!reg) {
-                    GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str());
+                    GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n",
+                        __func__, path_str(path).c_str());
                 } else {
                     GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
-                        __func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
+                        __func__, path_str(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
                 }
             }
             return nullptr;
         }
 
-        GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str());
+        GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str());
 
         register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
 
@@ -402,7 +411,7 @@ void ggml_backend_unload(ggml_backend_reg_t reg) {
     get_reg().unload_backend(reg, true);
 }
 
-static std::filesystem::path get_executable_path() {
+static fs::path get_executable_path() {
 #if defined(__APPLE__)
     // get executable path
     std::vector<char> path;
@@ -414,9 +423,15 @@ static std::filesystem::path get_executable_path() {
         }
         path.resize(size);
     }
-
-    return std::filesystem::path(path.data()).parent_path();
+    std::string base_path(path.data(), size);
+    // remove executable name
+    auto last_slash = base_path.find_last_of('/');
+    if (last_slash != std::string::npos) {
+        base_path = base_path.substr(0, last_slash);
+    }
+    return base_path + "/";
 #elif defined(__linux__) || defined(__FreeBSD__)
+    std::string base_path = ".";
     std::vector<char> path(1024);
     while (true) {
         // get executable path
@@ -429,48 +444,63 @@ static std::filesystem::path get_executable_path() {
             break;
         }
         if (len < (ssize_t) path.size()) {
-            return std::filesystem::path(path.data()).parent_path();
+            base_path = std::string(path.data(), len);
+            // remove executable name
+            auto last_slash = base_path.find_last_of('/');
+            if (last_slash != std::string::npos) {
+                base_path = base_path.substr(0, last_slash);
+            }
+            break;
         }
         path.resize(path.size() * 2);
     }
+
+    return base_path + "/";
 #elif defined(_WIN32)
     std::vector<wchar_t> path(MAX_PATH);
     DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
     if (len == 0) {
         return {};
     }
-
-    return std::filesystem::path(path.data()).parent_path();
-#endif
+    std::wstring base_path(path.data(), len);
+    // remove executable name
+    auto last_slash = base_path.find_last_of('\\');
+    if (last_slash != std::string::npos) {
+        base_path = base_path.substr(0, last_slash);
+    }
+    return base_path + L"\\";
+#else
     return {};
-}
-
-static std::string backend_filename_prefix() {
-#ifdef _WIN32
-    return "ggml-";
-#else
-    return "libggml-";
 #endif
 }
 
-static std::string backend_filename_suffix() {
+static fs::path backend_filename_prefix() {
 #ifdef _WIN32
-    return ".dll";
+    return fs::u8path("ggml-");
 #else
-    return ".so";
+    return fs::u8path("libggml-");
+#endif
+}
+
+static fs::path backend_filename_extension() {
+#ifdef _WIN32
+    return fs::u8path(".dll");
+#else
+    return fs::u8path(".so");
 #endif
 }
 
 static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
     // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
-     // TODO: search system paths
-    namespace fs = std::filesystem;
-    std::string file_prefix = backend_filename_prefix() + name + "-";
-    std::vector<fs::path> search_paths;
+    const fs::path name_path = fs::u8path(name);
+    const fs::path file_prefix = backend_filename_prefix().native() + name_path.native() + fs::u8path("-").native();
+    const fs::path file_extension = backend_filename_extension();
 
+    std::vector<fs::path> search_paths;
     if (user_search_path == nullptr) {
-        search_paths.push_back(fs::current_path());
+        // default search paths: executable directory, current directory
         search_paths.push_back(get_executable_path());
+        search_paths.push_back(fs::current_path());
     } else {
         search_paths.push_back(fs::u8path(user_search_path));
     }
@@ -480,31 +510,35 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
 
     for (const auto & search_path : search_paths) {
         if (!fs::exists(search_path)) {
+            GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
             continue;
         }
         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
         for (const auto & entry : dir_it) {
             if (entry.is_regular_file()) {
-                std::string filename = entry.path().filename().string();
-                std::string ext = entry.path().extension().string();
-                if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
-                    dl_handle_ptr handle { dl_load_library(entry.path()) };
-                    if (!handle) {
-                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
-                        continue;
+                auto filename = entry.path().filename();
+                auto ext = entry.path().extension();
+                if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
+                    dl_handle_ptr handle { dl_load_library(entry) };
+                    if (!handle && !silent) {
+                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str());
                     }
-
-                    auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
-                    if (!score_fn) {
-                        GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
-                        continue;
-                    }
-
-                    int s = score_fn();
-                    GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
-                    if (s > best_score) {
-                        best_score = s;
-                        best_path = entry.path();
+                    if (handle) {
+                        auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+                        if (score_fn) {
+                            int s = score_fn();
+#ifndef NDEBUG
+                            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_str(entry.path()).c_str(), s);
+#endif
+                            if (s > best_score) {
+                                best_score = s;
+                                best_path = entry.path();
+                            }
+                        } else {
+                            if (!silent) {
+                                GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, path_str(entry.path()).c_str());
+                            }
+                        }
                     }
                 }
             }
@@ -514,7 +548,8 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
     if (best_score == 0) {
         // try to load the base backend
         for (const auto & search_path : search_paths) {
-            fs::path path = fs::path(search_path) / (backend_filename_prefix() + name + backend_filename_suffix());
+            fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();
+            fs::path path = search_path / filename;
             if (fs::exists(path)) {
                 return get_reg().load_backend(path, silent);
             }
@@ -529,14 +564,6 @@ void ggml_backend_load_all() {
     ggml_backend_load_all_from_path(nullptr);
 }
 
-static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
-    try {
-        ggml_backend_load_best(name, silent, user_search_path);
-    } catch (const std::exception & e) {
-        GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
-    }
-}
-
 void ggml_backend_load_all_from_path(const char * dir_path) {
 #ifdef NDEBUG
     bool silent = true;
@@ -544,18 +571,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
     bool silent = false;
 #endif
 
-    ggml_backend_try_load_best("blas", silent, dir_path);
-    ggml_backend_try_load_best("cann", silent, dir_path);
-    ggml_backend_try_load_best("cuda", silent, dir_path);
-    ggml_backend_try_load_best("hip", silent, dir_path);
-    ggml_backend_try_load_best("kompute", silent, dir_path);
-    ggml_backend_try_load_best("metal", silent, dir_path);
-    ggml_backend_try_load_best("rpc", silent, dir_path);
-    ggml_backend_try_load_best("sycl", silent, dir_path);
-    ggml_backend_try_load_best("vulkan", silent, dir_path);
-    ggml_backend_try_load_best("opencl", silent, dir_path);
-    ggml_backend_try_load_best("musa", silent, dir_path);
-    ggml_backend_try_load_best("cpu", silent, dir_path);
+    ggml_backend_load_best("blas", silent, dir_path);
+    ggml_backend_load_best("cann", silent, dir_path);
+    ggml_backend_load_best("cuda", silent, dir_path);
+    ggml_backend_load_best("hip", silent, dir_path);
+    ggml_backend_load_best("kompute", silent, dir_path);
+    ggml_backend_load_best("metal", silent, dir_path);
+    ggml_backend_load_best("rpc", silent, dir_path);
+    ggml_backend_load_best("sycl", silent, dir_path);
+    ggml_backend_load_best("vulkan", silent, dir_path);
+    ggml_backend_load_best("opencl", silent, dir_path);
+    ggml_backend_load_best("musa", silent, dir_path);
+    ggml_backend_load_best("cpu", silent, dir_path);
     // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
     const char * backend_path = std::getenv("GGML_BACKEND_PATH");
     if (backend_path) {
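// Illustrative sketch (not part of the upstream patch): loading dynamic
// backends either from the default locations (executable directory, then the
// current directory) or from an explicit, hypothetical path.
#include "ggml-backend.h"

static void load_backends(const char * dir) {
    if (dir) {
        ggml_backend_load_all_from_path(dir);  // scan an explicit directory
    } else {
        ggml_backend_load_all();               // executable dir, then current dir
    }
}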
diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp
index 65e150d63..dd11f304f 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp
@@ -21,6 +21,7 @@
 #include 
 #include 
 #include 
+#include 
 
 #ifdef __APPLE__
 #include 
@@ -125,11 +126,12 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
     return base;
 }
 
-void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     // init_tensor is optional
     if (buffer->iface.init_tensor) {
-        buffer->iface.init_tensor(buffer, tensor);
+        return buffer->iface.init_tensor(buffer, tensor);
     }
+    return GGML_STATUS_SUCCESS;
 }
 
 void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -1641,7 +1643,7 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched,
 
 // utils
 
-void ggml_backend_view_init(struct ggml_tensor * tensor) {
+enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
     GGML_ASSERT(tensor->buffer == NULL);
     GGML_ASSERT(tensor->view_src != NULL);
     GGML_ASSERT(tensor->view_src->buffer != NULL);
@@ -1649,10 +1651,10 @@ void ggml_backend_view_init(struct ggml_tensor * tensor) {
 
     tensor->buffer = tensor->view_src->buffer;
     tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
-    ggml_backend_buffer_init_tensor(tensor->buffer, tensor);
+    return ggml_backend_buffer_init_tensor(tensor->buffer, tensor);
 }
 
-void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
     GGML_ASSERT(tensor->buffer == NULL);
     GGML_ASSERT(tensor->data == NULL);
     GGML_ASSERT(tensor->view_src == NULL);
@@ -1662,7 +1664,7 @@ void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor
 
     tensor->buffer = buffer;
     tensor->data = addr;
-    ggml_backend_buffer_init_tensor(buffer, tensor);
+    return ggml_backend_buffer_init_tensor(buffer, tensor);
 }
 
 static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
@@ -1708,7 +1710,8 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_
     struct ggml_tensor * dst = node_copies[id];
     if (dst->view_src != NULL) {
         graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src);
-        ggml_backend_view_init(dst);
+        enum ggml_status status = ggml_backend_view_init(dst);
+        GGML_ASSERT(status == GGML_STATUS_SUCCESS);
     }
     else {
         ggml_backend_tensor_copy(src, dst);
@@ -1823,7 +1826,6 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
     assert(g1->n_nodes == g2->n_nodes);
 
     for (int i = 0; i < g1->n_nodes; i++) {
-        //printf("eval %d/%d\n", i, g1->n_nodes);
         struct ggml_tensor * t1 = g1->nodes[i];
         struct ggml_tensor * t2 = g2->nodes[i];
 
diff --git a/ml/backend/ggml/ggml/src/ggml-common.h b/ml/backend/ggml/ggml/src/ggml-common.h
index 6c02b69ea..086c822d7 100644
--- a/ml/backend/ggml/ggml/src/ggml-common.h
+++ b/ml/backend/ggml/ggml/src/ggml-common.h
@@ -158,6 +158,12 @@ typedef sycl::half2 ggml_half2;
 
 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
 
+#ifdef _MSC_VER
+#define GGML_EXTENSION
+#else // _MSC_VER
+#define GGML_EXTENSION __extension__
+#endif // _MSC_VER
+
 #define QK4_0 32
 typedef struct {
     ggml_half d;           // delta
@@ -167,7 +173,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 b
 
 #define QK4_1 32
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d; // delta
             ggml_half m; // min
@@ -188,7 +194,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0
 
 #define QK5_1 32
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d; // delta
             ggml_half m; // min
@@ -209,7 +215,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block
 
 #define QK8_1 32
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d; // delta
             ggml_half s; // d * sum(qs[i])
@@ -250,7 +256,7 @@ static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
     uint8_t qs[QK_K/4];      // quants
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
@@ -277,7 +283,7 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12
 // weight is represented as x = a * q + b
 // Effectively 4.5 bits per weight
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
@@ -294,7 +300,7 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2,
 // weight is represented as x = a * q + b
 // Effectively 5.5 bits per weight
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
index aa5ad5d8d..e73a3b69b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
@@ -23,6 +23,16 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         ggml-cpu/amx/mmq.cpp
         ggml-cpu/amx/mmq.h
         ggml-cpu/ggml-cpu-impl.h
+        ggml-cpu/common.h
+        ggml-cpu/binary-ops.h
+        ggml-cpu/binary-ops.cpp
+        ggml-cpu/unary-ops.h
+        ggml-cpu/unary-ops.cpp
+        ggml-cpu/simd-mappings.h
+        ggml-cpu/vec.h
+        ggml-cpu/vec.cpp
+        ggml-cpu/ops.h
+        ggml-cpu/ops.cpp
         )
 
     target_compile_features(${GGML_CPU_NAME} PRIVATE c_std_11 cxx_std_17)
@@ -219,6 +229,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             if (GGML_AVX_VNNI)
                 list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
             endif()
+            if (GGML_BMI2)
+                # MSVC does not define macro __BMI2__
+                list(APPEND ARCH_DEFINITIONS __BMI2__ GGML_BMI2)
+            endif()
         else ()
             if (GGML_NATIVE)
                 list(APPEND ARCH_FLAGS -march=native)
@@ -233,6 +247,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                     list(APPEND ARCH_FLAGS -mfma)
                     list(APPEND ARCH_DEFINITIONS GGML_FMA)
                 endif()
+                if (GGML_BMI2)
+                    list(APPEND ARCH_FLAGS -mbmi2)
+                    list(APPEND ARCH_DEFINITIONS GGML_BMI2)
+                endif()
                 if (GGML_AVX)
                     list(APPEND ARCH_FLAGS -mavx)
                     list(APPEND ARCH_DEFINITIONS GGML_AVX)
@@ -279,21 +297,31 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 endif()
             endif()
         endif()
-    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
+    elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
         message(STATUS "PowerPC detected")
-        execute_process(COMMAND bash -c "grep POWER10 /proc/cpuinfo | head -n 1" OUTPUT_VARIABLE POWER10_M)
-        string(FIND "${POWER10_M}" "POWER10" substring_index)
-        if (NOT DEFINED substring_index OR "${substring_index}" STREQUAL "")
-            set(substring_index -1)
-        endif()
+        if (GGML_NATIVE)
+            if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
+                file(READ "/proc/cpuinfo" POWER10_M)
+            elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
+                execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
+            endif()
 
-        if (${substring_index} GREATER_EQUAL 0)
-        list(APPEND ARCH_FLAGS -mcpu=power10)
-        elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
-        list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
+            string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
+            string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
+
+            if (EXTRACTED_NUMBER GREATER_EQUAL 10)
+                list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
+            elseif (EXTRACTED_NUMBER EQUAL 9)
+                list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
+            elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
+                list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
+            else()
+                list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
+            endif()
         else()
-            list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
-            # TODO: Add  targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
+            if (GGML_CPU_POWERPC_CPUTYPE)
+                list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE})
+            endif()
         endif()
     elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
         message(STATUS "loongarch64 detected")
@@ -308,7 +336,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
     elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64")
         message(STATUS "RISC-V detected")
         if (GGML_RVV)
-            list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
+            if (GGML_RV_ZFH)
+                list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -DGGML_RV_ZFH -mabi=lp64d)
+            else()
+                list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
+            endif()
         endif()
     elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
         message(STATUS "s390x detected")
@@ -347,9 +379,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.3.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.5.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "060bd2dc64642b091f461cc8dd7426d9")
+        set(KLEIDIAI_ARCHIVE_MD5  "ea22e1aefb800e9bc8c74d91633cc58e")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp
index 5ec5263ce..0f067137d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp
@@ -50,10 +50,11 @@ static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
     return (void *) (buffer->context);
 }
 
-static void ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+static enum ggml_status ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);
 
     GGML_UNUSED(buffer);
+    return GGML_STATUS_SUCCESS;
 }
 
 static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/binary-ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/binary-ops.cpp
new file mode 100644
index 000000000..14f5b43ae
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/binary-ops.cpp
@@ -0,0 +1,158 @@
+#include "binary-ops.h"
+
+#if defined(GGML_USE_ACCELERATE)
+#include <Accelerate/Accelerate.h>
+
+using vDSP_fn_t = void (*)(const float *, vDSP_Stride, const float *, vDSP_Stride, float *, vDSP_Stride, vDSP_Length);
+#endif
+
+static inline float op_add(float a, float b) {
+    return a + b;
+}
+
+static inline float op_sub(float a, float b) {
+    return a - b;
+}
+
+static inline float op_mul(float a, float b) {
+    return a * b;
+}
+
+static inline float op_div(float a, float b) {
+    return a / b;
+}
+
+template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
+static inline void vec_binary_op_contiguous(const int64_t n, dst_t * z, const src0_t * x, const src1_t * y) {
+    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+    constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
+    constexpr auto f32_to_dst  = type_conversion_table<dst_t>::from_f32;
+
+    for (int i = 0; i < n; i++) {
+        z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(y[i])));
+    }
+}
+
+template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
+static inline void vec_binary_op_non_contiguous(const int64_t n, const int64_t ne10, const int64_t nb10, dst_t * z, const src0_t * x, const src1_t * y) {
+    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+    constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
+    constexpr auto f32_to_dst  = type_conversion_table<dst_t>::from_f32;
+
+    for (int i = 0; i < n; i++) {
+        int i10 = i % ne10;
+        const src1_t * y_ptr = (const src1_t *)((const char *)y + i10*nb10);
+        z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(*y_ptr)));
+    }
+}
+
+template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
+static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(dst_t));
+    GGML_ASSERT(nb00 == sizeof(src0_t));
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+    const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
+
+    if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
+        GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    }
+
+#ifdef GGML_USE_ACCELERATE
+    vDSP_fn_t vDSP_op = nullptr;
+    // TODO - avoid the f32-only check using type 'trait' lookup tables and row-based src-to-float conversion functions
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (op == op_add) {
+            vDSP_op = vDSP_vadd;
+        } else if (op == op_sub) {
+            vDSP_op = vDSP_vsub;
+        } else if (op == op_mul) {
+            vDSP_op = vDSP_vmul;
+        } else if (op == op_div) {
+            vDSP_op = vDSP_vdiv;
+        }
+    }
+#endif
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const int64_t i13 = i03 % ne13;
+        const int64_t i12 = i02 % ne12;
+        const int64_t i11 = i01 % ne11;
+
+        dst_t        * dst_ptr  = (dst_t  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+        const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+        if (is_src1_contiguous) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t nr0 = ne00 / ne10;
+
+            for (int64_t r = 0; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+                if constexpr (std::is_same_v<src0_t, float> && std::is_same_v<src1_t, float> && std::is_same_v<dst_t, float>) {
+                    if (vDSP_op != nullptr) {
+                        vDSP_op(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
+                        continue;
+                    }
+                }
+#endif
+                vec_binary_op_contiguous<op>(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        } else {
+            vec_binary_op_non_contiguous<op>(ne0, ne10, nb10, dst_ptr, src0_ptr, src1_ptr);
+        }
+    }
+}
+
+// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
+template <float (*op)(float, float)>
+static void binary_op(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    /*  */ if (src0->type == GGML_TYPE_F32  && src1->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32
+        apply_binary_op<op, float, float, float>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16  && src1->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16
+        apply_binary_op<op, ggml_fp16_t, ggml_fp16_t, ggml_fp16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+        apply_binary_op<op, ggml_bf16_t, ggml_bf16_t, ggml_bf16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_BF16) {
+        apply_binary_op<op, ggml_bf16_t, float, ggml_bf16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) {
+        apply_binary_op<op, ggml_bf16_t, float, float>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16  && src1->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F16) {
+        apply_binary_op<op, ggml_fp16_t, float, ggml_fp16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16  && src1->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) {
+        apply_binary_op<op, ggml_fp16_t, float, float>(params, dst);
+    } else {
+        GGML_ABORT("%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+    }
+}
+
+void ggml_compute_forward_add_non_quantized(const ggml_compute_params * params, ggml_tensor * dst) {
+    binary_op<op_add>(params, dst);
+}
+
+void ggml_compute_forward_sub(const ggml_compute_params * params, ggml_tensor * dst) {
+    binary_op<op_sub>(params, dst);
+}
+
+void ggml_compute_forward_mul(const ggml_compute_params * params, ggml_tensor * dst) {
+    binary_op<op_mul>(params, dst);
+}
+
+void ggml_compute_forward_div(const ggml_compute_params * params, ggml_tensor * dst) {
+    binary_op<op_div>(params, dst);
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/binary-ops.h b/ml/backend/ggml/ggml/src/ggml-cpu/binary-ops.h
new file mode 100644
index 000000000..aca1d89be
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/binary-ops.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void ggml_compute_forward_add_non_quantized(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_sub(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_mul(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_div(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/common.h b/ml/backend/ggml/ggml/src/ggml-cpu/common.h
new file mode 100644
index 000000000..3df01c1ed
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/common.h
@@ -0,0 +1,72 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-cpu-traits.h"
+#include "ggml-cpu-impl.h"
+#include "ggml-impl.h"
+
+#ifdef __cplusplus
+
+#include <utility>
+
+// convenience functions/macros for use in template calls
+// note: these won't be required after the 'traits' lookup table is used.
+static inline ggml_fp16_t f32_to_f16(float x) {
+    return GGML_FP32_TO_FP16(x);
+}
+
+static inline float f16_to_f32(ggml_fp16_t x) {
+    return GGML_FP16_TO_FP32(x);
+}
+
+static inline ggml_bf16_t f32_to_bf16(float x) {
+    return GGML_FP32_TO_BF16(x);
+}
+
+static inline float bf16_to_f32(ggml_bf16_t x) {
+    return GGML_BF16_TO_FP32(x);
+}
+
+static inline float f32_to_f32(float x) {
+    return x;
+}
+
+// TODO - merge this into the traits table, after using row-based conversions
+template <typename T>
+struct type_conversion_table;
+
+template <>
+struct type_conversion_table<ggml_fp16_t> {
+    static constexpr float (*to_f32)(ggml_fp16_t) = f16_to_f32;
+    static constexpr ggml_fp16_t (*from_f32)(float) = f32_to_f16;
+};
+
+template <>
+struct type_conversion_table<float> {
+    static constexpr float (*to_f32)(float) = f32_to_f32;
+    static constexpr float (*from_f32)(float) = f32_to_f32;
+};
+
+template <>
+struct type_conversion_table<ggml_bf16_t> {
+    static constexpr float (*to_f32)(ggml_bf16_t) = bf16_to_f32;
+    static constexpr ggml_bf16_t (*from_f32)(float) = f32_to_bf16;
+};
+
+static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_params * params, const struct ggml_tensor * src0) {
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
+
+    const int64_t nr  = ggml_nrows(src0);
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    return {ir0, ir1};
+}
+
+#endif
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/cpu-feats-x86.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/cpu-feats-x86.cpp
index e8133d411..902ee4346 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/cpu-feats-x86.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/cpu-feats-x86.cpp
@@ -278,6 +278,10 @@ static int ggml_backend_cpu_x86_score() {
     if (!is.SSE42()) { return 0; }
     score += 1<<2;
 #endif
+#ifdef GGML_BMI2
+    if (!is.BMI2()) { return 0; }
+    score += 1<<3;
+#endif
 #ifdef GGML_AVX
     if (!is.AVX()) { return 0; }
     score += 1<<4;
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
index b311a5b1c..74a31abb2 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
@@ -45,6 +45,24 @@ using block_q4_0x8 = block<4, 8>;
 using block_q8_0x4 = block<8, 4>;
 using block_q8_0x8 = block<8, 8>;
 
+
+struct block_q4_Kx8 {
+    ggml_half d[8];      // super-block scale for quantized scales
+    ggml_half dmin[8];   // super-block scale for quantized mins
+    uint8_t scales[96];  // scales and mins, quantized with 6 bits
+    uint8_t qs[1024];    // 4-bit quants
+};
+
+static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
+
+struct block_q8_Kx4 {
+    float d[4];              // delta
+    int8_t qs[QK_K * 4];     // quants
+    int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
+};
+
+static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
+
 struct block_iq4_nlx4 {
     ggml_half d[4];            // deltas for 4 iq4_nl blocks
     uint8_t   qs[QK4_NL * 2];  // nibbles / quants for 4 iq4_nl blocks
@@ -60,6 +78,13 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wro
 
 #define UNUSED GGML_UNUSED
 
+static inline int nearest_int(float fval) {
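+    // adding 12582912.0f (1.5 * 2^23) forces the float into a range where the low mantissa bits
+    // hold round-to-nearest(fval) + 2^22, so masking the mantissa and subtracting 0x00400000
+    // recovers the rounded signed integer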
+    assert(fabsf(fval) <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
 // Functions to create the interleaved data layout formats
 
 // interleave 4 block_q4_0s in blocks of blck_size_interleave
@@ -225,7 +250,7 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 
 static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
-static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+static void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
@@ -319,7 +344,7 @@ static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRIC
 #endif
 }
 
-static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
@@ -534,16 +559,289 @@ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
 #endif
 }
 
-static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+static void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK_K == 256);
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
+
+#if defined(__AVX2__)
+    float iscale[4];
+    __m256 srcv[4][32];
+    __m256 iscale_vec[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            // Load elements into 4 AVX vectors
+            __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
+            __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
+            __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
+            __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
+
+            // Compute max(abs(e)) for the block
+            const __m256 signBit = _mm256_set1_ps( -0.0f );
+            __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
+            __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
+            __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
+            __m256 abs3 = _mm256_andnot_ps( signBit, v3 );
+
+            __m256 maxAbs = _mm256_max_ps( abs0, abs1 );
+            maxAbs = _mm256_max_ps( maxAbs, abs2 );
+            maxAbs = _mm256_max_ps( maxAbs, abs3 );
+
+            __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
+            __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
+            __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
+            __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
+
+            __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
+
+            srcv[row_iter][0] = v0;
+            srcv[row_iter][1] = v1;
+            srcv[row_iter][2] = v2;
+            srcv[row_iter][3] = v3;
+
+            for (int sb = 1; sb < 8; sb++) {
+                // Temporarily stores absolute quant values
+                __m256 tempAbs = maxAbs;
+
+                // Load elements into 4 AVX vectors
+                __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
+                __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
+                __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
+                __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
+
+                // Compute max(abs(e)) for the block
+                __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
+                __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
+                __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
+                __m256 abs3 = _mm256_andnot_ps( signBit, v3 );
+
+                maxAbs = _mm256_max_ps( maxAbs, abs0 );
+                maxAbs = _mm256_max_ps( maxAbs, abs1 );
+                maxAbs = _mm256_max_ps( maxAbs, abs2 );
+                maxAbs = _mm256_max_ps( maxAbs, abs3 );
+
+                __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
+                maskAbs = _mm256_and_ps( maskAbs, mask_prev );
+
+                mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
+                mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
+                mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
+                mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
+
+                __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
+                maskAbs =  _mm256_or_ps(maskAbs, mask_curr);
+
+                srcv[row_iter][sb * 4] = v0;
+                srcv[row_iter][sb * 4 + 1] = v1;
+                srcv[row_iter][sb * 4 + 2] = v2;
+                srcv[row_iter][sb * 4 + 3] = v3;
+            }
+
+            __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+            max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+            max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+            const float maxScalar = _mm_cvtss_f32( max4 );
+
+            __m256 maxScalarVec = _mm256_set1_ps(maxScalar);
+
+            __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
+            __m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
+
+            const int mask = _mm256_movemask_ps(finalMask);
+            iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+
+            if(mask) {
+                iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
+            }
+
+            y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
+            iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
+        }
+
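+        // Quantize all 4 rows; after the pack/permute below, each 32-byte store holds 8 consecutive
+        // quants from row 0 followed by 8 from row 1, row 2 and row 3 (the interleaved q8_Kx4 layout)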
+        __m256i quants_interleaved[32];
+        for (int j = 0; j < 32; j++) {
+            // Apply the multiplier
+            __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
+            __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
+            __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
+            __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
+
+            // Round to nearest integer
+            v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+            v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+            v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+            v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+            // Convert floats to integers
+            __m256i i0 = _mm256_cvtps_epi32( v0 );
+            __m256i i1 = _mm256_cvtps_epi32( v1 );
+            __m256i i2 = _mm256_cvtps_epi32( v2 );
+            __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+            // Convert int32 to int16
+            i0 = _mm256_packs_epi32( i0, i1 );
+            i2 = _mm256_packs_epi32( i2, i3 );
+            // Convert int16 to int8
+            i0 = _mm256_packs_epi16( i0, i2 );
+
+            //  Permute and store the quantized weights in the required order after the pack instruction
+            const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+            i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+            _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
+            quants_interleaved[j] = i0;
+        }
+
+        // Masks to shuffle the quants of corresponding sub blocks for rearranging quants for vectorized bsums computation
+        __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
+        shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
+        __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
+        shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
+        __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
+        shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
+
+        for (int k = 0; k < 4; k++) {
+            // Quants from four different sub blocks are taken
+            __m256i q0 = quants_interleaved[k * 8 + 0];
+            __m256i q1 = quants_interleaved[k * 8 + 1];
+            __m256i q2 = quants_interleaved[k * 8 + 2];
+            __m256i q3 = quants_interleaved[k * 8 + 3];
+            __m256i q4 = quants_interleaved[k * 8 + 4];
+            __m256i q5 = quants_interleaved[k * 8 + 5];
+            __m256i q6 = quants_interleaved[k * 8 + 6];
+            __m256i q7 = quants_interleaved[k * 8 + 7];
+
+
+            // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
+            __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
+            __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
+            __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
+            sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
+            __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
+            sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
+
+            __m256i one = _mm256_set1_epi8(1);
+            __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
+
+            for (int l = 0; l < 3; l++) {
+                // Quants value shifted to process next two values from each sub block
+                q0 = _mm256_srli_epi64(q0, 16);
+                q2 = _mm256_srli_epi64(q2, 16);
+                q4 = _mm256_srli_epi64(q4, 16);
+                q6 = _mm256_srli_epi64(q6, 16);
+
+                sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
+                sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
+                sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
+                sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
+                sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
+                sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
+
+                bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
+            }
+
+            // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
+            __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
+            __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
+            __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
+            sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
+            __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
+            sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
+
+            __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
+
+            for (int l = 0; l < 3; l++) {
+                // Quants value shifted to process next two values from each sub block
+                q1 = _mm256_srli_epi64(q1, 16);
+                q3 = _mm256_srli_epi64(q3, 16);
+                q5 = _mm256_srli_epi64(q5, 16);
+                q7 = _mm256_srli_epi64(q7, 16);
+
+                sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
+                sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
+                sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
+                sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
+                sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
+                sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
+
+                bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
+            }
+
+            // Overall bsums in interleaved fashion computed by adding results of both halves
+            __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
+            _mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
+        }
+    }
+
+#else
+
+    // scalar
+    const int blck_size_interleave = 8;
+    float srcv[4][QK_K];
+    float iscale[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+            float max = 0;
+
+            for (int j = 0; j < QK_K; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
+                // Update the maximum value of the corresponding super block
+                if(amax < fabsf(srcv[row_iter][j])) {
+                    amax = fabsf(srcv[row_iter][j]);
+                    max = srcv[row_iter][j];
+                }
+            }
+
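+            // The scale is chosen so that the value with the largest magnitude maps exactly to -127
+            // (max keeps its sign, so the sign of iscale follows the sign of that value)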
+            iscale[row_iter] = amax ? -127.f/max : 0;
+
+            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
+        }
+
+        for (int j = 0; j < QK_K / 4; j++) {
+            y[i].bsums[j] = 0;
+        }
+
+        // Quants values are interleaved in sequence of eight bytes from corresponding super blocks
+        // Bsums values are interleaved in sequences of four bsums from each super block taken for interleaving,
+        // i.e. the first four bsums from the first super block, followed by the first four bsums from the second super block, and so on
+        for (int j = 0; j < QK_K * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id     = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+            int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
+
+            float x0 = srcv[src_id][src_offset] * iscale[src_id];
+            y[i].qs[j] = nearest_int(x0);
+            y[i].bsums[index] += y[i].qs[j];
+        }
+    }
+#endif
+}
+
+template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
+void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
+
+template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
     assert(nrow == 4);
     UNUSED(nrow);
-    if (blck_size_interleave == 4) {
-        quantize_q8_0_4x4(x, vy, n_per_row);
-    } else if (blck_size_interleave == 8) {
-        quantize_q8_0_4x8(x, vy, n_per_row);
-    } else {
-        assert(false);
-    }
+    ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
+}
+
+template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
+}
+
+template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
 }
 
 static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -994,6 +1292,281 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
     }
 }
 
+static void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
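+    // q4_K dequantizes as w = d * sc * q - dmin * m, so the dot product with a q8_K row splits into
+    // two terms: sum(sc * q4.q8) scaled by d * d8, minus sum(m * bsum) scaled by dmin * d8.
+    // Both the AVX2 and the scalar paths below accumulate the two terms separately and subtract at the end.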
+#if defined(__AVX2__)
+    // Lookup table to convert signed nibbles to signed bytes
+    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+    // Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
+    __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
+    __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
+    // Permute mask used for easier vector processing at later stages
+    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+
+    // Mask to extract nibbles from bytes
+    const __m256i m4b = _mm256_set1_epi8(0x0F);
+
+    int64_t b_nb = n / QK_K;
+
+    const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 *)vx;
+    const block_q8_K * a_ptr_start = (const block_q8_K *)vy;
+
+    // Process Q8_K blocks one by one
+    for (int64_t y = 0; y < nr; y++) {
+
+        // Pointers to LHS blocks of block_q8_K format
+        const block_q8_K * a_ptr = a_ptr_start + (y * nb);
+
+        // Take group of eight interleaved block_q4_K structures at each pass of the loop and perform dot product operation
+        for (int64_t x = 0; x < nc / 8; x++) {
+
+            // Pointers to RHS blocks
+            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
+
+            // Master FP accumulators
+            __m256 acc_row = _mm256_setzero_ps();
+            __m256 acc_min_rows = _mm256_setzero_ps();
+
+            for (int64_t b = 0; b < nb; b++) {
+
+                // Load and convert to FP32 scale from block_q8_K
+                const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));
+
+                // Load the scale values for the 8 blocks interleaved in block_q4_Kx8
+                // col_scale_f32 rearranged so as to multiply with appropriate quants
+                const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
+                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
+
+                __m256i iacc_b = _mm256_setzero_si256();
+                __m256i iacc_min_b = _mm256_setzero_si256();
+
+                const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
+                __m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
+                q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
+
+                // Processes two sub blocks from each Q4_K in each iteration
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+
+                    // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+                    const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
+                    const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
+
+                    // 4-bit -> 8-bit
+                    // Values of the first sub block of eight block_q4_K structures for the sb loop
+                    const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
+                    const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
+                    const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
+                    const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
+                    const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
+                    const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
+                    const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
+                    const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
+
+                    // Values of the second sub block of eight block_q4_K structures when sb = 1
+                    const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
+                    const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
+                    const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
+                    const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
+                    const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
+                    const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
+                    const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
+                    const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
+
+                    uint32_t utmp_0[4], utmp_1[4];
+
+                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
+                    // The block below extracts the first sub block's scales and mins from the different Q4_K structures for the sb loop
+                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
+                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp_0[1] & kmask1;
+                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
+                    utmp_0[2] = uaux_0;
+                    utmp_0[0] &= kmask1;
+
+                    // The block below extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop
+                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
+                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_1 = utmp_1[1] & kmask1;
+                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
+                    utmp_1[2] = uaux_1;
+                    utmp_1[0] &= kmask1;
+
+                    // Scales of first sub block in the sb loop
+                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
+                    __m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
+                    __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
+
+                    // Scales of second sub block in the sb loop
+                    __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
+                    __m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
+                    __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
+
+                    // Mins of first and second sub block of Q4_K block are arranged side by side
+                    __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
+
+                    // Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
+                    __m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
+                    __m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
+                    __m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
+                    __m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));
+
+                    lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
+                    lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
+                    lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
+                    lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
+
+                    // Dot product done within 32 bit lanes and accumulated in the same vector
+                    // First done for the first sub block and then for the second sub block in each sb
+                    // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
+                    // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
+                    // ...........................................................................
+                    // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
+
+
+                    __m256i iacc_0 = _mm256_setzero_si256();
+                    __m256i iacc_1 = _mm256_setzero_si256();
+
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
+
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
+
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
+
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
+
+                    iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
+
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
+
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
+
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
+
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
+
+                    iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
+
+                    // Accumulate the iacc value for one sb
+                    __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
+
+                    // Broadcast the bsums of the two sub blocks of this Q8_K iteration across the vector
+                    // Multiply-add the bsums with the corresponding mins of Q4_Kx8
+                    __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
+                    __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
+                    q8s = _mm256_bsrli_epi128(q8s, 4);
+
+                    // Accumulate for the complete block
+                    iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
+                    iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
+                }
+
+                // Multiply-Add with scale values for the complete super block
+                acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
+                acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);
+
+            }
+
+            // Accumulated output values permuted so as to be stored in appropriate order post accumulation
+            acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
+            _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));
+        }
+    }
+
+#else
+
+    float sumf[8];
+    float sum_minf[8];
+    uint32_t utmp[32];
+    int sumi1;
+    int sumi2;
+    int sumi;
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+            sum_minf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
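+            // Each 12-byte group in b_ptr[l].scales packs the 6-bit scales and 6-bit mins of one
+            // sub block across the 8 interleaved q4_K blocks; unpack them into byte-aligned values
+            // (per sub block: bytes 0-7 of the utmp group hold the scales, bytes 8-15 the mins)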
+            for (int sb = 0; sb < 8; sb++) {
+                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
+                utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                utmp[sb * 4 + 2] = uaux_0;
+                utmp[sb * 4 + 0] &= kmask1;
+            }
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
+                uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi1 = 0;
+                    sumi2 = 0;
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
+                        sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);
+                        sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);
+                        sumi1 = sumi1 * scales_0[j];
+                        sumi2 = sumi2 * scales_1[j];
+                        sumi += sumi1 + sumi2;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                }
+            }
+            for (int sb = 0; sb < 8; sb++) {
+                uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
+        }
+    }
+#endif
+}
+
+
 static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -3480,6 +4053,781 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
     }
 }
 
+static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__AVX2__)
+    const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx;
+    const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;
+    int64_t b_nb = n / QK_K;
+    int64_t y = 0;
+
+    // Mask to mask out nibbles from packed bytes
+    const __m256i m4b = _mm256_set1_epi8(0x0F);
+    // Permute mask used for easier vector processing at later stages
+    __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
+
+    int anr = nr - nr % 16; // Used to align nr with a boundary of 16
+    // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
+    for (; y < anr / 4; y += 4) {
+
+        const block_q8_Kx4 * a_ptrs[4];
+
+        a_ptrs[0] = a_ptr_start + (y * nb);
+        for (int i = 0; i < 3; ++i) {
+            a_ptrs[i + 1] = a_ptrs[i] + nb;
+        }
+
+        // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
+        for (int64_t x = 0; x < nc / 8; x++) {
+
+            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
+
+            // Master FP accumulators
+            __m256 acc_rows[16];
+            for (int i = 0; i < 16; i++) {
+                acc_rows[i] = _mm256_setzero_ps();
+            }
+
+            __m256 acc_min_rows[16];
+            for (int i = 0; i < 16; i++) {
+                acc_min_rows[i] = _mm256_setzero_ps();
+            }
+
+            // For super block
+            for (int64_t b = 0; b < nb; b++) {
+
+                // Scale values - Load the eight scale values of block_q4_kx8
+                const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+
+                // dmin values - Load the eight dmin values of block_q4_kx8
+                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
+
+                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+
+                    // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
+
+                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
+                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
+
+                    // 4-bit -> 8-bit
+                    // First sub block of the two sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
+                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
+
+                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
+                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
+
+                    const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
+                    const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
+
+                    const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
+                    const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
+
+                    // Second sub block of the two sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
+                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
+
+                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
+                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
+
+                    const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
+                    const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
+
+                    const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
+                    const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
+                    const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
+
+                    const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
+                    const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
+
+                    const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
+                    const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
+
+                    const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
+                    const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
+
+                    const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
+                    const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
+
+                    const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
+                    const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
+
+                    const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
+                    const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
+
+                    const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
+                    const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
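+
+                    // Note: the shuffle immediate 136 (binary 10 00 10 00) selects 32-bit elements {0,2,0,2}
+                    // of each 128-bit lane, i.e. the even 4-byte groups of the B rows, duplicated so they can
+                    // be paired with the matching left-hand shuffle pattern further below.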
+
+
+                    // Shuffle pattern two - right side input
+                    const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
+                    const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
+
+                    const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
+                    const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
+
+                    const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
+                    const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
+
+                    const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
+                    const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
+
+                    const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
+                    const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
+
+                    const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
+                    const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
+
+                    const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
+                    const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
+
+                    const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
+                    const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
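+
+                    // Note: immediate 221 (binary 11 01 11 01) selects elements {1,3,1,3}, the odd 4-byte
+                    // groups, so shuffle patterns one and two together cover all 32 quants of each B row.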
+
+                    uint32_t utmp_0[4], utmp_1[4];
+
+                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
+                    // The block below extracts, for example, the first sub block's scales and mins from the different Q4_K structures for the current sb iteration
+                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
+                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp_0[1] & kmask1;
+                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
+                    utmp_0[2] = uaux_0;
+                    utmp_0[0] &= kmask1;
+
+                    // The block below extracts the second sub block's scales and mins from the different Q4_K structures for the current sb iteration
+                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
+                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_1 = utmp_1[1] & kmask1;
+                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
+                    utmp_1[2] = uaux_1;
+                    utmp_1[0] &= kmask1;
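+
+                    // Note (assuming the usual Q4_K scale packing): the 12 copied bytes hold 6-bit scales
+                    // and 6-bit mins; the kmask1/kmask2/kmask3 constants (defined outside this hunk) expand
+                    // them so that utmp_*[0..1] hold byte-aligned scales and utmp_*[2..3] hold the mins.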
+
+                    // Scales of first sub block in the sb loop
+                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
+                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+
+                    // Scales of second sub block in the sb loop
+                    const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
+                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+
+                    // Mins of first and second sub block of Q4_K block are arranged side by side
+                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
+
+                    const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
+                    const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
+
+                    const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
+                    const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
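+
+                    // Note: _mm_unpacklo_epi8(x, x) duplicates each scale byte, so after widening every
+                    // scale occupies two adjacent 16-bit slots - matching _mm256_madd_epi16, which later
+                    // sums adjacent products. The shuffle immediate 78 (binary 01 00 11 10) swaps the
+                    // 64-bit halves, moving the mins bytes into the low half before they are unpacked.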
+
+                    for (int rp = 0; rp < 4; rp++) {
+
+                        // Load the quantized values of the four Q8_K rows (A0,A1,A2,A3), interleaved with each other in chunks of eight bytes
+                        // Loaded as 256 bit vectors; each 128 bit half is then repeated across a full 256 bit vector
+                        __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
+                        __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
+                        __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
+                        __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
+                        __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
+                        __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
+                        __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
+                        __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
+                        __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
+                        __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
+                        __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
+                        __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
+                        __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
+                        __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
+                        __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
+                        __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
+                        __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
+                        __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
+                        __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
+                        __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
+                        __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
+                        __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
+                        __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
+                        __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
+
+                        // Bsums are loaded - four bsums (covering the two sub blocks) from each of the different Q8_K blocks
+                        __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
+                        __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
+                        lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
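+
+                        // Note: _mm_hadd_epi16 appears to collapse the two per-16-value bsums of each sub
+                        // block into one sum per (row, sub block); permute2x128 with imm 0 then broadcasts
+                        // the low 128 bits across both lanes for the per-row shuffles used later.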
+
+                        // Shuffle pattern one - left side input
+                        const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
+                        const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
+
+                        const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160);  //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
+                        const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160);  //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
+
+                        const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160);  //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
+                        const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160);  //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
+
+                        const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160);  //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
+                        const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
+
+                        const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
+                        const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
+
+                        const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160);  //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
+                        const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160);  //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
+
+                        const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160);  //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
+                        const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160);  //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
+
+                        const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
+                        const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
+
+                        // Shuffle pattern two - left side input
+                        const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
+                        const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
+
+                        const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
+                        const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
+
+                        const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
+                        const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
+
+                        const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
+                        const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
+
+                        const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
+                        const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
+
+                        const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
+                        const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
+
+                        const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
+                        const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
+
+                        const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
+                        const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
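+
+                        // Note: on the left side the immediates 160 (binary 10 10 00 00) and 245 (binary
+                        // 11 11 01 01) duplicate elements {0,0,2,2} and {1,1,3,3}; combined with the
+                        // {0,2,0,2} / {1,3,1,3} patterns on the right side, every A 4-byte group is lined
+                        // up against every B 4-byte group for the maddubs dot products below.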
+
+                        // The values arranged in the shuffle patterns are combined with a dot product: within each 32 bit lane, corresponding bytes are multiplied and the adjacent products accumulated
+                        __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
+                        __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
+                        __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
+                        __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
+                        __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
+                        __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
+                        __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
+                        __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
+
+                        __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
+                        __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
+                        __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
+                        __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
+                        __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
+                        __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
+                        __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
+                        __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
+
+                        // Outputs of both shuffle patterns are added in order to sum the dot product outputs over all 32 values in the block
+                        __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
+                        __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
+                        __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
+                        __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
+
+                        __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
+                        __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
+                        __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
+                        __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
+
+                        // Multiply the summed dot products of each sub block by its scales (widening the accumulators to 32 bit integers)
+                        iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
+                        iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
+                        iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
+                        iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
+
+                        iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
+                        iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
+                        iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
+                        iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
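+
+                        // Note: since each scale is stored in two adjacent 16-bit slots, _mm256_madd_epi16
+                        // computes scale * (sum of the adjacent 16-bit partial sums), widening the
+                        // accumulators from 16-bit to 32-bit per-column totals.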
+
+                        // Straighten out to make 4 row vectors per sub block (the two sub blocks' rows are accumulated together in the next step)
+                        __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
+                        __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
+                        __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
+                        __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
+                        __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
+                        __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
+                        __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
+                        __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
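+
+                        // Note: blend mask 204 (binary 11001100) takes 32-bit elements 2,3,6,7 from the
+                        // second operand and shuffle 78 swaps 64-bit halves, so each iacc_row_* vector
+                        // holds the eight column results of one output row (per sub block), which are then
+                        // summed across the two sub blocks.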
+
+                        __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
+                        __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
+                        __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
+                        __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
+
+                        // Load the scale values for all four Q8_K blocks and repeat them across both lanes
+                        const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
+                        const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);//GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
+
+                        // Multiply with the appropriate scales and accumulate (for both d and dmin) below
+                        acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+                        acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+                        acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                        acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
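+
+                        // Note: _mm256_shuffle_ps with immediates 0/85/170/255 broadcasts the d value of
+                        // row 0/1/2/3 respectively, so each accumulated column total is scaled by
+                        // (row d) * (column d) before being added to the running float accumulators.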
+
+                        __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
+                        __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
+                        __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
+                        __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
+
+                        acc_min_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
+                        acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
+                        acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
+                        acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
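+
+                        // Note (based on the Q4_K dequantization x = d * scale * q - dmin * min): the
+                        // bsums multiplied by the sub block mins, scaled by dmin and the row scale,
+                        // accumulate the correction term that is subtracted from acc_rows when the
+                        // results are stored.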
+
+                    }
+                }
+            }
+            // Store the accumulated values
+            for (int i = 0; i < 16; i++) {
+                _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
+            }
+        }
+    }
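+    // Remaining groups of four rows: same per-block computation as above, but without the four-way
+    // unrolling over row groups (a single a_ptr instead of the a_ptrs[rp] loop).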
+    for (; y < nr / 4; y++) {
+
+        const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
+
+        for (int64_t x = 0; x < nc / 8; x++) {
+
+            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
+
+            // Master FP accumulators
+            __m256 acc_rows[4];
+            for (int i = 0; i < 4; i++) {
+                acc_rows[i] = _mm256_setzero_ps();
+            }
+
+            __m256 acc_min_rows[4];
+            for (int i = 0; i < 4; i++) {
+                acc_min_rows[i] = _mm256_setzero_ps();
+            }
+
+            for (int64_t b = 0; b < nb; b++) {
+
+                // Scale values - Load the eight scale values of block_q4_Kx8
+                const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+
+                // dmin values - Load the eight dmin values of block_q4_Kx8
+                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
+
+                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+
+                    // Load the quantized values of the eight Q4_K blocks for two sub blocks, interleaved with each other in chunks of eight bytes - B0,B1 ... B6,B7
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
+
+                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
+                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
+
+                    // 4-bit -> 8-bit
+                    // First sub block of the two sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
+                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
+
+                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
+                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
+
+                    const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
+                    const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
+
+                    const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
+                    const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
+
+                    // Second sub block of the two sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
+                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
+
+                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
+                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
+
+                    const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
+                    const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
+
+                    const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
+                    const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
+                    const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
+
+                    const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
+                    const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
+
+                    const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
+                    const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
+
+                    const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
+                    const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
+
+                    const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
+                    const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
+
+                    const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
+                    const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
+
+                    const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
+                    const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
+
+                    const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
+                    const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
+
+                    // Shuffle pattern two - right side input
+                    const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
+                    const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
+
+                    const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
+                    const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
+
+                    const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
+                    const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
+
+                    const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
+                    const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
+
+                    const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
+                    const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
+
+                    const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
+                    const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
+
+                    const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
+                    const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
+
+                    const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
+                    const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
+
+                    uint32_t utmp_0[4], utmp_1[4];
+
+                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
+                    // The block below extracts, for example, the first sub block's scales and mins from the different Q4_K structures for the current sb iteration
+                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
+                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp_0[1] & kmask1;
+                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
+                    utmp_0[2] = uaux_0;
+                    utmp_0[0] &= kmask1;
+
+                    // The block below extracts the second sub block's scales and mins from the different Q4_K structures for the current sb iteration
+                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
+                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_1 = utmp_1[1] & kmask1;
+                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
+                    utmp_1[2] = uaux_1;
+                    utmp_1[0] &= kmask1;
+
+                    // Scales of first sub block in the sb loop
+                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
+                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+
+                    // Scales of second sub block in the sb loop
+                    const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
+                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+
+                    // Mins of first and second sub block of Q4_K block are arranged side by side
+                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
+
+                    const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
+                    const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
+
+                    const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
+                    const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
+
+                    // Load the quantized values of the four Q8_K rows (A0,A1,A2,A3), interleaved with each other in chunks of eight bytes
+                    // Loaded as 256 bit vectors; each 128 bit half is then repeated across a full 256 bit vector
+                    __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
+                    __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
+                    __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
+                    __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
+                    __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
+                    __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
+                    __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
+                    __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
+                    __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
+                    __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
+                    __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
+                    __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
+                    __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
+                    __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
+                    __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
+                    __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
+                    __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
+                    __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
+                    __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
+                    __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
+                    __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
+                    __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
+                    __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
+                    __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
+
+                    // Bsums are loaded - four bsums (covering the two sub blocks) from each of the different Q8_K blocks
+                    __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
+                    __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
+                    lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
+
+                    // Shuffle pattern one - left side input
+                    const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
+                    const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
+
+                    const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160);  //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
+                    const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160);  //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
+
+                    const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160);  //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
+                    const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160);  //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
+
+                    const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160);  //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
+                    const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
+
+                    const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
+                    const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
+
+                    const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160);  //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
+                    const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160);  //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
+
+                    const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160);  //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
+                    const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160);  //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
+
+                    const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
+                    const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
+
+                    // Shuffle pattern two - left side input
+                    const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
+                    const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
+
+                    const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
+                    const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
+
+                    const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
+                    const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
+
+                    const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
+                    const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
+
+                    const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
+                    const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
+
+                    const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
+                    const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
+
+                    const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
+                    const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
+
+                    const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
+                    const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
+
+                    // The values arranged in the shuffle patterns are combined with a dot product: within each 32 bit lane, corresponding bytes are multiplied and the adjacent products accumulated
+                    __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
+                    __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
+                    __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
+                    __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
+                    __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
+                    __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
+                    __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
+                    __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
+
+                    __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
+                    __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
+                    __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
+                    __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
+                    __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
+                    __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
+                    __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
+                    __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
+
+                    // The outputs of the two shuffle patterns are added to accumulate the dot products over all 32 values in the sub-block
+                    __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
+                    __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
+                    __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
+                    __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
+
+                    __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
+                    __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
+                    __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
+                    __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
+
+                    // The 16-bit dot-product sums are multiplied with the corresponding sub-block scales and the adjacent products added into 32-bit integers
+                    iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
+                    iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
+                    iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
+                    iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
+
+                    iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
+                    iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
+                    iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
+                    iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
+
+                    // Straighten out to make 4 row vectors per sub block (the two sub blocks are accumulated together in the next step)
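+                    // (shuffle imm 78 = 0b01001110 swaps the two 64-bit halves within each 128-bit lane; blend mask 204 = 0b11001100 selects alternating 64-bit pairs from the two sources)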
+                    __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
+                    __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
+                    __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
+                    __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
+                    __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
+                    __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
+                    __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
+                    __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
+
+                    __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
+                    __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
+                    __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
+                    __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
+
+                    // Load the scale (d) values of the 4 Q8_K rows in the block and repeat them across lanes
+                    const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
+                    const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); //GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
+
+                    // Multiply with the appropriate scales and accumulate (for both d and dmin) below
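+                    // (_mm256_shuffle_ps with imm 0/85/170/255 broadcasts the scale of row 0/1/2/3 across all eight lanes)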
+                    acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+                    acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+                    acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+                    acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+
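+                    // Accumulate the min correction: the per-sub-block activation sums (bsums) are multiplied with the corresponding mins, scaled by dmin and subtracted from the result at the end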
+                    __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
+                    __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
+                    __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
+                    __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
+
+                    acc_min_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
+                    acc_min_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
+                    acc_min_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
+                    acc_min_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
+                }
+            }
+
+            // Store the accumulated values
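+            // (d-scaled dot products minus the dmin-scaled min correction)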
+            for (int i = 0; i < 4; i++) {
+                _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
+            }
+        }
+    }
+
+#else
+
+    float sumf[4][8];
+    float sum_minf[4][8];
+    uint32_t utmp[32];
+    int sumi1;
+    int sumi2;
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0;
+                    sum_minf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int sb = 0; sb < 8; sb++) {
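+                    // Expand the 12 packed bytes of this group into 16 bytes of utmp: 8 scale bytes followed by 8 min bytes (one per interleaved column)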
+                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
+                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                    utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                    utmp[sb * 4 + 2] = uaux_0;
+                    utmp[sb * 4 + 0] &= kmask1;
+                }
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
+                    uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
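+                    // scales_0/scales_1 point at the scale bytes of the two sub-blocks held in the low and high nibbles of this group of quants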
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi1 = 0;
+                            sumi2 = 0;
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
+                                sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);
+                                sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
+                                sumi1 = sumi1 * scales_0[j];
+                                sumi2 = sumi2 * scales_1[j];
+                                sumi += sumi1 + sumi2;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+                for (int sb = 0; sb < 8; sb++) {
+                    uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
+                    for(int m = 0; m < 4; m++) {
+                        const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
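+                        // the two adjacent bsums entries are the 16-element partial sums covering sub-block sb of row m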
+                        for(int j = 0; j < ncols_interleaved; j++) {
+                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
+                }
+            }
+        }
+    }
+#endif
+}
+
 static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -3660,6 +5008,82 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
     return out;
 }
 
+static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
+    block_q4_Kx8 out;
+    // The delta (d) and dmin values of the eight Q4_K structures are copied into the interleaved output structure
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
+    }
+
+    for (int i = 0; i < 8; i++) {
+        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
+    }
+
+    const int end = QK_K * 4 / blck_size_interleave;
+
+    // Interleave Q4_K quants by taking 8 bytes at a time
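+    // (QK_K / 2 = 128 quant bytes per block x 8 blocks = 1024 bytes in total, copied 8 bytes at a time while cycling through the 8 source blocks)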
+    for (int i = 0; i < end; ++i) {
+        int src_id = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        uint64_t elems;
+        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
+        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+    }
+
+    // The logic below unpacks and rearranges the scale and min values of the eight Q4_K structures
+    // Each Q4_K structure packs its 8 scales and 8 mins into 12 bytes (6 bits per value)
+    // The output Q4_Kx8 structure stores 96 bytes of scales
+    // Each 12-byte group holds the scales and mins of one sub-block, gathered from the corresponding sub-block of all eight Q4_K structures
+    // E.g. the first 12 bytes hold 8 scales and 8 mins - one for the first sub-block of each Q4_K structure
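+    // Within each 12-byte group: bytes 0-3 hold the low 6 bits of scales 0-3 plus the upper 2 bits of scales 4-7,
+    // bytes 4-7 hold the same for the mins, and bytes 8-11 hold the low 4 bits of scales 4-7 (low nibble) and mins 4-7 (high nibble)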
+    uint8_t s[8], m[8];
+
+    for (int i = 0; i < 4; i++) {
+        for (int j = 0; j < 8; j++) {
+            s[j] = in[j].scales[i] & 63;
+            m[j] = in[j].scales[i + 4] & 63;
+        }
+
+        out.scales[i * 12]      = (s[0] & 63) + ((s[4] & 48) << 2);
+        out.scales[i * 12 + 1]  = (s[1] & 63) + ((s[5] & 48) << 2);
+        out.scales[i * 12 + 2]  = (s[2] & 63) + ((s[6] & 48) << 2);
+        out.scales[i * 12 + 3]  = (s[3] & 63) + ((s[7] & 48) << 2);
+        out.scales[i * 12 + 4]  = (m[0] & 63) + ((m[4] & 48) << 2);
+        out.scales[i * 12 + 5]  = (m[1] & 63) + ((m[5] & 48) << 2);
+        out.scales[i * 12 + 6]  = (m[2] & 63) + ((m[6] & 48) << 2);
+        out.scales[i * 12 + 7]  = (m[3] & 63) + ((m[7] & 48) << 2);
+        out.scales[i * 12 + 8]  = (s[4] & 15) + ((m[4] & 15) << 4);
+        out.scales[i * 12 + 9]  = (s[5] & 15) + ((m[5] & 15) << 4);
+        out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
+        out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
+
+    }
+
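+    // Same packing for sub-blocks 4-7, whose 6-bit scales and mins are reassembled from the upper 2 bits of scales[0..7] and the nibbles of scales[8..11] of each source block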
+    for (int i = 0; i < 4; i++) {
+        for (int j = 0; j < 8; j++) {
+            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15);
+            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4);
+        }
+
+        out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
+        out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
+        out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
+        out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
+        out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
+        out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
+        out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
+        out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
+        out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
+        out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
+        out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
+        out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
+
+    }
+
+    return out;
+}
+
 static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
     GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
@@ -3690,6 +5114,36 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
 
     GGML_UNUSED(data_size);
 }
+static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
+    GGML_ASSERT(interleave_block == 8);
+    constexpr int nrows_interleaved = 8;
+
+    block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
+    const block_q4_K * src = (const block_q4_K*) data;
+    block_q4_K dst_tmp[8];
+    int nrow = ggml_nrows(t);
+    int nblocks = t->ne[0] / QK_K;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
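+    // For every group of 8 consecutive rows, gather block x of each row and pack them into one interleaved block_q4_Kx8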
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i  = 0; i < nrows_interleaved; i++ ) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
 
 static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
@@ -3807,6 +5261,10 @@ template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * da
     return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
+}
+
 template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
 }
@@ -3817,44 +5275,50 @@ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void *
 //}
 
 // gemv
-template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemv(int, float *, size_t, const void *, const void *, int, int);
 
-template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <>
-void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
 // gemm
-template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemm(int, float *, size_t, const void *, const void *, int, int);
 
-template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <>
-void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
@@ -3863,37 +5327,37 @@ class tensor_traits_base : public ggml::cpu::tensor_traits {
     virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
 };
 
-template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {
 
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
         // not realy a GGML_TYPE_Q8_0 but same size.
         switch (op->op) {
-        case GGML_OP_MUL_MAT:
-            size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
-            return true;
-        case GGML_OP_MUL_MAT_ID:
-            size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
-            size = GGML_PAD(size, sizeof(int64_t));  // + padding for next bloc.
-            size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
-            return true;
-        default:
-            // GGML_ABORT("fatal error");
-            break;
+            case GGML_OP_MUL_MAT:
+                size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
+                return true;
+            case GGML_OP_MUL_MAT_ID:
+                size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
+                size = GGML_PAD(size, sizeof(int64_t));  // + padding for next bloc.
+                size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
+                return true;
+            default:
+                // GGML_ABORT("fatal error");
+                break;
         }
         return false;
     }
 
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
         switch (op->op) {
-        case GGML_OP_MUL_MAT:
-            forward_mul_mat(params, op);
-            return true;
-        case GGML_OP_MUL_MAT_ID:
-            forward_mul_mat_id(params, op);
-            return true;
-        default:
-            // GGML_ABORT("fatal error");
-            break;
+            case GGML_OP_MUL_MAT:
+                forward_mul_mat(params, op);
+                return true;
+            case GGML_OP_MUL_MAT_ID:
+                forward_mul_mat_id(params, op);
+                return true;
+            default:
+                // GGML_ABORT("fatal error");
+                break;
         }
         return false;
     }
@@ -3925,17 +5389,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
         // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
 
         char *       wdata = static_cast<char *>(params->wdata);
-        const size_t nbw1  = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+        const size_t nbw1  = ggml_row_size(PARAM_TYPE, ne10);
 
         assert(params->wsize >= nbw1 * ne11);
 
-        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
+        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
 
         int64_t i11_processed = 0;
         for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-            quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
-                              INTER_SIZE);
+            ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
         }
+
         i11_processed = ne11 - ne11 % 4;
         for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
             from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
@@ -3944,26 +5408,28 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
         ggml_barrier(params->threadpool);
 
         const void * src1_wdata      = params->wdata;
-        const size_t src1_col_stride = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
         int64_t      src0_start      = (ith * ne01) / nth;
         int64_t      src0_end        = ((ith + 1) * ne01) / nth;
         src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
-        src0_end   = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+        src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
         if (src0_start >= src0_end) {
             return;
         }
 
         // If there are more than three rows in src1, use gemm; otherwise, use gemv.
         if (ne11 > 3) {
-            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
-                                                 (const char *) src0->data + src0_start * nb01,
-                                                 (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
         }
         for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
-            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                                                 (const char *) src0->data + src0_start * nb01,
-                                                 (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                                                 src0_end - src0_start);
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                    src0_end - src0_start);
         }
     }
 
@@ -3978,7 +5444,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
         const int ith = params->ith;
         const int nth = params->nth;
 
-        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
+        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
 
         // we don't support permuted src0 or src1
         GGML_ASSERT(nb00 == ggml_type_size(src0->type));
@@ -4000,7 +5466,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
         const int n_ids = ids->ne[0]; // n_expert_used
         const int n_as  = ne02;       // n_expert
 
-        const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
 
@@ -4012,12 +5478,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
         GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
                                       n_as * ne12 * sizeof(mmid_row_mapping)));
 
-        auto                      wdata             = (char *) params->wdata;
-        auto                      wdata_src1_end    = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
-        int64_t *                 matrix_row_counts = (int64_t *) (wdata_src1_end);                      // [n_as]
+        auto * wdata             = (char *)     params->wdata;
+        auto * wdata_src1_end    = (char *)     wdata + GGML_PAD(nbw3, sizeof(int64_t));
+        auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+
         struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as);  // [n_as][ne12]
 
-        // src1: float32 => block_q8_0
+        // src1: float32 => param type
         for (int64_t i12 = 0; i12 < ne12; ++i12) {
             for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
                 from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
@@ -4056,34 +5523,37 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
                 continue;
             }
 
-            auto src0_cur = (const char *) src0->data + cur_a*nb02;
+            const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
 
             //const int64_t nr0 = ne01; // src0 rows
             const int64_t nr1 = cne1; // src1 rows
 
             int64_t src0_cur_start = (ith * ne01) / nth;
             int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;
-            src0_cur_start =
-                (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
-            src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
 
-            if (src0_cur_start >= src0_cur_end) return;
+            src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
+            src0_cur_end   = (src0_cur_end   % NB_COLS) ? src0_cur_end   + NB_COLS - (src0_cur_end   % NB_COLS) : src0_cur_end;
+
+            if (src0_cur_start >= src0_cur_end) {
+                return;
+            }
 
             for (int ir1 = 0; ir1 < nr1; ir1++) {
                 struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
-                const int id       = row_mapping.i1; // selected expert index
 
-                const int64_t  i11 = id % ne11;
-                const int64_t  i12 = row_mapping.i2; // row index in src1
+                const int id = row_mapping.i1; // selected expert index
 
-                const int64_t  i1 = id;  // selected expert index
-                const int64_t  i2 = i12; // row
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = row_mapping.i2; // row index in src1
 
-                auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+                const int64_t i1 = id;  // selected expert index
+                const int64_t i2 = i12; // row
 
-                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
-                        ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
-                        ne01,                    src0_cur + src0_cur_start * nb01,
+                const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+
+                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                        (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
+                        src0_cur + src0_cur_start * nb01,
                         src1_col, 1, src0_cur_end - src0_cur_start);
             }
         }
@@ -4098,12 +5568,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
 };
 
 // instance for Q4
-static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
-static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
-static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
+static const tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
+static const tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
+static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
+static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
 
 // instance for IQ4
-static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
+static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
 
 }  // namespace ggml::cpu::aarch64
 
@@ -4124,6 +5595,12 @@ static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(con
                 return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
             }
         }
+    } else if (cur->type == GGML_TYPE_Q4_K) {
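+        // use the 8-row interleaved Q4_K layout only when AVX2 is available and the number of rows is a multiple of 8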
+        if (ggml_cpu_has_avx2()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &ggml::cpu::aarch64::q4_K_8x8_q8_K;
+            }
+        }
     } else if (cur->type == GGML_TYPE_IQ4_NL) {
         if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
             if (cur->ne[1] % 4 == 0) {
@@ -4135,10 +5612,11 @@ static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(con
     return nullptr;
 }
 
-static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+static enum ggml_status ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_aarch64_get_optimal_repack_type(tensor));
 
     GGML_UNUSED(buffer);
+    return GGML_STATUS_SUCCESS;
 }
 
 static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
index 7f7d210cb..e4af07635 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -4,13 +4,13 @@
 
 #include "ggml.h"
 #include "ggml-impl.h"
+
 #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
 //#include <stddef.h>
 #include <stdbool.h>
 #include <string.h> // memcpy
 #include <math.h>   // fabsf
 
-
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -69,33 +69,16 @@ struct ggml_compute_params {
 #endif
 
 #if defined(__ARM_FEATURE_SVE)
-#include <arm_sve.h>
 #include <sys/prctl.h>
 #endif
 
-// 16-bit float
-// on Arm, we use __fp16
-// on x86, we use uint16_t
 #if defined(__ARM_NEON)
 
-// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
-//
-//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
-//
-#include <arm_neon.h>
-
+// ref: https://github.com/ggml-org/llama.cpp/pull/5404
 #ifdef _MSC_VER
-
-typedef uint16_t ggml_fp16_internal_t;
-
 #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
-
 #else
-
-typedef __fp16 ggml_fp16_internal_t;
-
 #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
-
 #endif // _MSC_VER
 
 #if !defined(__aarch64__)
@@ -340,8 +323,6 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
 #else
 #ifdef __POWER9_VECTOR__
 #include <altivec.h>
-#undef bool
-#define bool _Bool
 #else
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <intrin.h>
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c
index 8d5e3e20b..91a81bdc3 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c
@@ -719,28 +719,28 @@ static inline __m128i packNibbles( __m256i bytes ) {
 }
 #endif  //__loongarch_asx
 
-void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
+void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
     quantize_row_q4_0_ref(x, y, k);
 }
 
-void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
+void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
     quantize_row_q4_1_ref(x, y, k);
 }
 
-void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
+void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
     quantize_row_q5_0_ref(x, y, k);
 }
 
-void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
+void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
     quantize_row_q5_1_ref(x, y, k);
 }
 
-void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) {
+void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
 
-    block_q8_0 * restrict y = vy;
+    block_q8_0 * GGML_RESTRICT y = vy;
 
 #if defined(__ARM_NEON)
     for (int i = 0; i < nb; i++) {
@@ -891,15 +891,15 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
     }
 #elif defined(__riscv_v_intrinsic)
 
-    size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+    size_t vl = QK8_0;
 
     for (int i = 0; i < nb; i++) {
         // load elements
-        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+        vfloat32m8_t v_x   = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);
 
-        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
         vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
-        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
         float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
 
         const float d = amax / ((1 << 7) - 1);
@@ -907,14 +907,14 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
 
         y[i].d = GGML_FP32_TO_FP16(d);
 
-        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);
 
         // convert to integer
-        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
-        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+        vint16m4_t   vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+        vint8m2_t    vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);
 
         // store result
-        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+        __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
     }
 
 #elif defined(__POWER9_VECTOR__)
@@ -1050,11 +1050,11 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
 #endif
 }
 
-void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) {
+void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(k % QK8_1 == 0);
     const int nb = k / QK8_1;
 
-    block_q8_1 * restrict y = vy;
+    block_q8_1 * GGML_RESTRICT y = vy;
 
 #if defined(__ARM_NEON)
     for (int i = 0; i < nb; i++) {
@@ -1229,15 +1229,15 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
     }
 #elif defined(__riscv_v_intrinsic)
 
-    size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+    size_t vl = QK8_1;
 
     for (int i = 0; i < nb; i++) {
         // load elements
-        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+        vfloat32m8_t v_x   = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);
 
-        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
         vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);
-        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
         float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
 
         const float d  = amax / ((1 << 7) - 1);
@@ -1245,18 +1245,18 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 
         y[i].d = GGML_FP32_TO_FP16(d);
 
-        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);
 
         // convert to integer
-        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
-        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+        vint16m4_t   vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+        vint8m2_t    vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);
 
         // store result
-        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+        __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
 
         // compute sum for y[i].s
         vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
-        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);
 
         // set y[i].s
         int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
@@ -1428,8 +1428,8 @@ static inline int nearest_int(float fval) {
     return (i & 0x007fffff) - 0x00400000;
 }
 
-static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
-        const float * restrict qw) {
+static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, int rmse_type,
+        const float * GGML_RESTRICT qw) {
     float max = 0;
     float amax = 0;
     for (int i = 0; i < n; ++i) {
@@ -1497,7 +1497,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
     return scale;
 }
 
-static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
+static float make_q3_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, bool do_rmse) {
     float max = 0;
     float amax = 0;
     for (int i = 0; i < n; ++i) {
@@ -1556,7 +1556,7 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
     return 1/iscale;
 }
 
-static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+static float make_qkx1_quants(int n, int nmax, const float * GGML_RESTRICT x, uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min,
         int ntry, float alpha) {
     float min = x[0];
     float max = x[0];
@@ -1599,8 +1599,8 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
     return scale;
 }
 
-static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
-        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights,
+        uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, uint8_t * GGML_RESTRICT Laux,
         float rmin, float rdelta, int nstep, bool use_mad) {
     float min = x[0];
     float max = x[0];
@@ -1680,7 +1680,7 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const f
     return scale;
 }
 
-static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+static inline void get_scale_min_k4(int j, const uint8_t * GGML_RESTRICT q, uint8_t * GGML_RESTRICT d, uint8_t * GGML_RESTRICT m) {
     if (j < 4) {
         *d = q[j] & 63; *m = q[j + 4] & 63;
     } else {
@@ -1691,51 +1691,51 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
 
 //========================- 2-bit (de)-quantization
 
-void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
+void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     quantize_row_q2_K_ref(x, vy, k);
 }
 
 //========================= 3-bit (de)-quantization
 
-void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
+void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     quantize_row_q3_K_ref(x, vy, k);
 }
 
 // ====================== 4-bit (de)-quantization
 
-void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
+void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(k % QK_K == 0);
-    block_q4_K * restrict y = vy;
+    block_q4_K * GGML_RESTRICT y = vy;
     quantize_row_q4_K_ref(x, y, k);
 }
 
 // ====================== 5-bit (de)-quantization
 
-void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
+void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(k % QK_K == 0);
-    block_q5_K * restrict y = vy;
+    block_q5_K * GGML_RESTRICT y = vy;
     quantize_row_q5_K_ref(x, y, k);
 }
 
 // ====================== 6-bit (de)-quantization
 
-void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
+void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(k % QK_K == 0);
-    block_q6_K * restrict y = vy;
+    block_q6_K * GGML_RESTRICT y = vy;
     quantize_row_q6_K_ref(x, y, k);
 }
 
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
-void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) {
+void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(k % QK_K == 0);
-    block_tq1_0 * restrict y = vy;
+    block_tq1_0 * GGML_RESTRICT y = vy;
     quantize_row_tq1_0_ref(x, y, k);
 }
 
-void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) {
+void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(k % QK_K == 0);
-    block_tq2_0 * restrict y = vy;
+    block_tq2_0 * GGML_RESTRICT y = vy;
     quantize_row_tq2_0_ref(x, y, k);
 }
 
@@ -1743,11 +1743,11 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1
 
 //===================================== Q8_K ==============================================
 
-void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
+void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
 #ifdef __wasm_simd128__
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
-    block_q8_K * restrict yc = y; // Cast to proper type
+    block_q8_K * GGML_RESTRICT yc = y; // Cast to proper type
 
     for (int i = 0; i < nb; i++) {
         const float * x_block = x + i * QK_K;
@@ -1909,7 +1909,7 @@ static inline __m128i get_scale_shuffle(int i) {
 }
 #endif
 
-void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
 
@@ -1924,23 +1924,23 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     UNUSED(by);
     UNUSED(bs);
 
-    const block_q4_0 * restrict x = vx;
-    const block_q8_0 * restrict y = vy;
+    const block_q4_0 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
 
 #if defined(__ARM_FEATURE_MATMUL_INT8)
     if (nrc == 2) {
-        const block_q4_0 * restrict vx0 = vx;
-        const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
-        const block_q8_0 * restrict vy0 = vy;
-        const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
+        const block_q4_0 * GGML_RESTRICT vx0 = vx;
+        const block_q4_0 * GGML_RESTRICT vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
+        const block_q8_0 * GGML_RESTRICT vy0 = vy;
+        const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
 
         float32x4_t sumv0 = vdupq_n_f32(0.0f);
 
         for (int i = 0; i < nb; i++) {
-            const block_q4_0 * restrict b_x0 = &vx0[i];
-            const block_q4_0 * restrict b_x1 = &vx1[i];
-            const block_q8_0 * restrict b_y0 = &vy0[i];
-            const block_q8_0 * restrict b_y1 = &vy1[i];
+            const block_q4_0 * GGML_RESTRICT b_x0 = &vx0[i];
+            const block_q4_0 * GGML_RESTRICT b_x1 = &vx1[i];
+            const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i];
+            const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i];
 
             const uint8x16_t m4b = vdupq_n_u8(0x0F);
             const int8x16_t  s8b = vdupq_n_s8(0x8);
@@ -2017,10 +2017,10 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
                 const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
 
                 for (; ib + 1 < nb; ib += 2) {
-                    const block_q4_0 * restrict x0 = &x[ib + 0];
-                    const block_q4_0 * restrict x1 = &x[ib + 1];
-                    const block_q8_0 * restrict y0 = &y[ib + 0];
-                    const block_q8_0 * restrict y1 = &y[ib + 1];
+                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
+                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
+                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
+                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 
                     // load x
                     const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
@@ -2063,10 +2063,10 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
                 const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
 
                 for (; ib + 1 < nb; ib += 2) {
-                    const block_q4_0 * restrict x0 = &x[ib + 0];
-                    const block_q4_0 * restrict x1 = &x[ib + 1];
-                    const block_q8_0 * restrict y0 = &y[ib + 0];
-                    const block_q8_0 * restrict y1 = &y[ib + 1];
+                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
+                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
+                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
+                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 
                     // load x
                     const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
@@ -2104,10 +2104,10 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
                 const svbool_t pl16 = svnot_b_z(ph32, ph16);
 
                 for (; ib + 1 < nb; ib += 2) {
-                    const block_q4_0 * restrict x0 = &x[ib + 0];
-                    const block_q4_0 * restrict x1 = &x[ib + 1];
-                    const block_q8_0 * restrict y0 = &y[ib + 0];
-                    const block_q8_0 * restrict y1 = &y[ib + 1];
+                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
+                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
+                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
+                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 
                     // load x
                     const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
@@ -2144,10 +2144,10 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (; ib + 1 < nb; ib += 2) {
-        const block_q4_0 * restrict x0 = &x[ib + 0];
-        const block_q4_0 * restrict x1 = &x[ib + 1];
-        const block_q8_0 * restrict y0 = &y[ib + 0];
-        const block_q8_0 * restrict y1 = &y[ib + 1];
+        const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
+        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
+        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
+        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
         const int8x16_t  s8b = vdupq_n_s8(0x8);
@@ -2189,10 +2189,10 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     const v128_t s8b = wasm_i8x16_splat(0x8);
 
     for (; ib + 1 < nb; ib += 2) {
-        const block_q4_0 * restrict x0 = &x[ib];
-        const block_q4_0 * restrict x1 = &x[ib + 1];
-        const block_q8_0 * restrict y0 = &y[ib];
-        const block_q8_0 * restrict y1 = &y[ib + 1];
+        const block_q4_0 * GGML_RESTRICT x0 = &x[ib];
+        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
+        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
+        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 
         // Load and process x0
         v128_t v0_0 = wasm_v128_load(x0->qs);
@@ -2391,33 +2391,31 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
     sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
 #elif defined(__riscv_v_intrinsic)
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+    size_t vl = qk / 2;
 
     for (; ib < nb; ++ib) {
         // load elements
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);
 
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);
 
         // mask and store lower part of x, and then upper part
-        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
 
-        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
 
         // subtract offset
-        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
-        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
+        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
+        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
 
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
 
         int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
@@ -2609,7 +2607,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     *s = sumf;
 }
 
-void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_1;
     const int nb = n / qk;
 
@@ -2624,24 +2622,24 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     UNUSED(by);
     UNUSED(bs);
 
-    const block_q4_1 * restrict x = vx;
-    const block_q8_1 * restrict y = vy;
+    const block_q4_1 * GGML_RESTRICT x = vx;
+    const block_q8_1 * GGML_RESTRICT y = vy;
 
 #if defined(__ARM_FEATURE_MATMUL_INT8)
     if (nrc == 2) {
-        const block_q4_1 * restrict vx0 = vx;
-        const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
-        const block_q8_1 * restrict vy0 = vy;
-        const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);
+        const block_q4_1 * GGML_RESTRICT vx0 = vx;
+        const block_q4_1 * GGML_RESTRICT vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
+        const block_q8_1 * GGML_RESTRICT vy0 = vy;
+        const block_q8_1 * GGML_RESTRICT vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);
 
         float32x4_t sumv0 = vdupq_n_f32(0.0f);
         float32x4_t summs0 = vdupq_n_f32(0.0f);
 
         for (int i = 0; i < nb; i++) {
-            const block_q4_1 * restrict b_x0 = &vx0[i];
-            const block_q4_1 * restrict b_x1 = &vx1[i];
-            const block_q8_1 * restrict b_y0 = &vy0[i];
-            const block_q8_1 * restrict b_y1 = &vy1[i];
+            const block_q4_1 * GGML_RESTRICT b_x0 = &vx0[i];
+            const block_q4_1 * GGML_RESTRICT b_x1 = &vx1[i];
+            const block_q8_1 * GGML_RESTRICT b_y0 = &vy0[i];
+            const block_q8_1 * GGML_RESTRICT b_y1 = &vy1[i];
 
             float32_t summs_t[4] = {
                 GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
@@ -2715,10 +2713,10 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     float summs = 0;
 
     for (; ib + 1 < nb; ib += 2) {
-        const block_q4_1 * restrict x0 = &x[ib + 0];
-        const block_q4_1 * restrict x1 = &x[ib + 1];
-        const block_q8_1 * restrict y0 = &y[ib + 0];
-        const block_q8_1 * restrict y1 = &y[ib + 1];
+        const block_q4_1 * GGML_RESTRICT x0 = &x[ib + 0];
+        const block_q4_1 * GGML_RESTRICT x1 = &x[ib + 1];
+        const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];
+        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];
 
         summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
 
@@ -2783,29 +2781,27 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 
     sumf = hsum_float_8(acc) + summs;
 #elif defined(__riscv_v_intrinsic)
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+    size_t vl = qk / 2;
 
     for (; ib < nb; ++ib) {
         // load elements
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);
 
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);
 
         // mask and store lower part of x, and then upper part
-        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
 
-        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
 
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
 
         int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
@@ -2931,7 +2927,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     *s = sumf;
 }
 
-void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
 
@@ -2946,8 +2942,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     UNUSED(by);
     UNUSED(bs);
 
-    const block_q5_0 * restrict x = vx;
-    const block_q8_0 * restrict y = vy;
+    const block_q5_0 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
 
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
@@ -2960,10 +2956,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     uint64_t tmp1[4];
 
     for (; ib + 1 < nb; ib += 2) {
-        const block_q5_0 * restrict x0 = &x[ib];
-        const block_q5_0 * restrict x1 = &x[ib + 1];
-        const block_q8_0 * restrict y0 = &y[ib];
-        const block_q8_0 * restrict y1 = &y[ib + 1];
+        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
+        const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];
+        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
+        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
@@ -3024,8 +3020,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
     // TODO: check if unrolling this is better
     for (; ib < nb; ++ib) {
-        const block_q5_0 * restrict x0 = &x[ib];
-        const block_q8_0 * restrict y0 = &y[ib];
+        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
+        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
 
         const v128_t m4b  = wasm_i8x16_splat(0x0F);
 
@@ -3132,65 +3128,33 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
     sumf = hsum_float_8(acc);
 #elif defined(__riscv_v_intrinsic)
-    uint32_t qh;
-
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
-    // These temporary registers are for masking and shift operations
-    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
-    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
-
-    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
-    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+    size_t vl;
+    size_t vlenb = __riscv_vlenb();
 
     for (; ib < nb; ++ib) {
-        memcpy(&qh, x[ib].qh, sizeof(uint32_t));
+        vl = qk / 2;
+        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+        vint8m2_t v0c;
+        if (vlenb == 16) {
+            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+        } else {
+            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+        }
 
-        // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
-        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
-        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+        vl = qk;
+        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+        qh = __riscv_vmnand_mm_b4(qh, qh, vl);
+        vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
 
-        // ((qh & (1u << (j + 16))) >> (j + 12));
-        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
-        vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
-
-        // narrowing
-        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
-        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
-
-        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
-        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
-
-        // load
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
-
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
-
-        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
-        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
-        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
-        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
-        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
-        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
-
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
-        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-
-        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+        sumf += (GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)) * sumi;
     }
 
 #elif defined(__POWER9_VECTOR__)
@@ -3286,7 +3250,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     *s = sumf;
 }
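The rewritten q5_0 loop above no longer expands qh through 32-bit shift/mask vectors: it loads the 32 high bits straight into a vector mask (__riscv_vlm_v_b4), glues the low and high nibbles into one 32-element i8 vector (vcreate when VLEN is 128 bits, a slideup otherwise), and applies the -16 offset only where the high bit is clear. That yields the same value as the old ((q & 0xF) | (bit << 4)) - 16; a scalar sketch for element j of block ib (qh below is the 32-bit value the removed code memcpy'd out of x[ib].qh; illustrative only):

    const int nib = j < qk/2 ? (x[ib].qs[j] & 0x0F)        // elements 0..15: low nibbles
                             : (x[ib].qs[j - qk/2] >> 4);  // elements 16..31: high nibbles
    const int bit = (qh >> j) & 1;                         // 5th weight bit
    const int v   = (nib | (bit << 4)) - 16;               // == bit ? nib : nib - 16
    sumi += v * y[ib].qs[j];

The q5_1 loop further down reuses the same mask load but ORs in 0x10 where the bit is set instead of subtracting, since q5_1 carries an explicit per-block min.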
 
-void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_1;
     const int nb = n / qk;
 
@@ -3301,8 +3265,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     UNUSED(by);
     UNUSED(bs);
 
-    const block_q5_1 * restrict x = vx;
-    const block_q8_1 * restrict y = vy;
+    const block_q5_1 * GGML_RESTRICT x = vx;
+    const block_q8_1 * GGML_RESTRICT y = vy;
 
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
@@ -3318,10 +3282,10 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     uint64_t tmp1[4];
 
     for (; ib + 1 < nb; ib += 2) {
-        const block_q5_1 * restrict x0 = &x[ib];
-        const block_q5_1 * restrict x1 = &x[ib + 1];
-        const block_q8_1 * restrict y0 = &y[ib];
-        const block_q8_1 * restrict y1 = &y[ib + 1];
+        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
+        const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];
+        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
+        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];
 
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
@@ -3387,8 +3351,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 
     // TODO: check if unrolling this is better
     for (; ib < nb; ++ib) {
-        const block_q5_1 * restrict x0 = &x[ib];
-        const block_q8_1 * restrict y0 = &y[ib];
+        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
+        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
 
         summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
 
@@ -3503,60 +3467,30 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 
     sumf = hsum_float_8(acc) + summs;
 #elif defined(__riscv_v_intrinsic)
-    uint32_t qh;
-
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
-    // temporary registers for shift operations
-    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
-    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+    size_t vl;
+    size_t vlenb = __riscv_vlenb();
 
     for (; ib < nb; ++ib) {
-        memcpy(&qh, x[ib].qh, sizeof(uint32_t));
+        vl = qk / 2;
+        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+        vint8m2_t v0c;
+        if (vlenb == 16) {
+            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+        } else {
+            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+        }
 
-        // load qh
-        vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
-
-        // ((qh >> (j +  0)) << 4) & 0x10;
-        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
-        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
-        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
-
-        // ((qh >> (j + 12))     ) & 0x10;
-        vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
-        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
-
-        // narrowing
-        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
-        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
-
-        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
-        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
-
-        // load
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
-
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
-
-        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
-        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
-        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
-        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
-        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+        vl = qk;
+        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+        vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
 
         sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
     }
@@ -3660,7 +3594,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     *s = sumf;
 }
 
-void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
 
@@ -3675,24 +3609,24 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     UNUSED(by);
     UNUSED(bs);
 
-    const block_q8_0 * restrict x = vx;
-    const block_q8_0 * restrict y = vy;
+    const block_q8_0 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
 
 #if defined(__ARM_FEATURE_MATMUL_INT8)
     if (nrc == 2) {
-        const block_q8_0 * restrict vx0 = vx;
-        const block_q8_0 * restrict vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
-        const block_q8_0 * restrict vy0 = vy;
-        const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
+        const block_q8_0 * GGML_RESTRICT vx0 = vx;
+        const block_q8_0 * GGML_RESTRICT vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
+        const block_q8_0 * GGML_RESTRICT vy0 = vy;
+        const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
 
         float32x4_t sumv0 = vdupq_n_f32(0.0f);
 
         for (int i = 0; i < nb; i++) {
-            const block_q8_0 * restrict b_x0 = &vx0[i];
-            const block_q8_0 * restrict b_y0 = &vy0[i];
+            const block_q8_0 * GGML_RESTRICT b_x0 = &vx0[i];
+            const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i];
 
-            const block_q8_0 * restrict b_x1 = &vx1[i];
-            const block_q8_0 * restrict b_y1 = &vy1[i];
+            const block_q8_0 * GGML_RESTRICT b_x1 = &vx1[i];
+            const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i];
 
             const int8x16_t x0_l = vld1q_s8(b_x0->qs);
             const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
@@ -3757,10 +3691,10 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
                 const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
 
                 for (; ib + 1 < nb; ib += 2) {
-                    const block_q8_0 * restrict x0 = &x[ib + 0];
-                    const block_q8_0 * restrict x1 = &x[ib + 1];
-                    const block_q8_0 * restrict y0 = &y[ib + 0];
-                    const block_q8_0 * restrict y1 = &y[ib + 1];
+                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
+                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
+                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
+                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 
                     // load x
                     const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
@@ -3788,10 +3722,10 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             {
                 //printf("sve256");
                 for (; ib + 1 < nb; ib += 2) {
-                    const block_q8_0 * restrict x0 = &x[ib + 0];
-                    const block_q8_0 * restrict x1 = &x[ib + 1];
-                    const block_q8_0 * restrict y0 = &y[ib + 0];
-                    const block_q8_0 * restrict y1 = &y[ib + 1];
+                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
+                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
+                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
+                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 
                     // load x
                     const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
@@ -3824,10 +3758,10 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
                 svfloat32_t sumv00 = svdup_n_f32(0.0f);
 
                 for (; ib + 1 < nb; ib += 2) {
-                    const block_q8_0 * restrict x0 = &x[ib + 0];
-                    const block_q8_0 * restrict x1 = &x[ib + 1];
-                    const block_q8_0 * restrict y0 = &y[ib + 0];
-                    const block_q8_0 * restrict y1 = &y[ib + 1];
+                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
+                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
+                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
+                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 
                     //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits
                     // and add them to make one 64 element vector
@@ -3867,10 +3801,10 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (; ib + 1 < nb; ib += 2) {
-        const block_q8_0 * restrict x0 = &x[ib + 0];
-        const block_q8_0 * restrict x1 = &x[ib + 1];
-        const block_q8_0 * restrict y0 = &y[ib + 0];
-        const block_q8_0 * restrict y1 = &y[ib + 1];
+        const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
+        const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
+        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
+        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 
         const int8x16_t x0_0 = vld1q_s8(x0->qs);
         const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
@@ -3897,8 +3831,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     v128_t sumv = wasm_f32x4_splat(0.0f);
 
     for (; ib < nb; ++ib) {
-        const block_q8_0 * restrict x0 = &x[ib];
-        const block_q8_0 * restrict y0 = &y[ib];
+        const block_q8_0 * GGML_RESTRICT x0 = &x[ib];
+        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
 
         const v128_t x0_0 = wasm_v128_load(x0->qs);
         const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
@@ -3970,17 +3904,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
     sumf = hsum_float_8(accum);
 #elif defined(__riscv_v_intrinsic)
-    size_t vl = __riscv_vsetvl_e8m1(qk);
+    size_t vl = qk;
 
     for (; ib < nb; ++ib) {
         // load elements
-        vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl);
-        vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+        vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);
+        vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
 
-        vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
+        vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);
 
         vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
-        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);
 
         int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
 
@@ -4080,15 +4014,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     *s = sumf;
 }
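With the explicit vsetvl gone, the q8_0 path above simply passes vl = qk (32) to each intrinsic and widens at LMUL 2→4 (i8m2 loads, i16m4 products, one widening reduction). Per block it is the plain int8 dot product scaled by both fp16 deltas — a minimal scalar sketch, reusing the surrounding names:

    int sumi = 0;
    for (int j = 0; j < qk; ++j) {
        sumi += x[ib].qs[j] * y[ib].qs[j];
    }
    sumf += sumi * GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d);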
 
-void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
     UNUSED(bs);
 
-    const block_tq1_0 * restrict x = vx;
-    const block_q8_K  * restrict y = vy;
+    const block_tq1_0 * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -4403,15 +4337,15 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
 #endif
 }
 
-void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
     UNUSED(bs);
 
-    const block_tq2_0 * restrict x = vx;
-    const block_q8_K  * restrict y = vy;
+    const block_tq2_0 * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -4575,19 +4509,264 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
 #endif
 }
 
-void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
     UNUSED(bs);
 
-    const block_q2_K * restrict x = vx;
-    const block_q8_K * restrict y = vy;
+    const block_q2_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = svcntb()*8;
+    const svuint8_t m3s = svdup_n_u8(0x3);
+    const svuint32_t m4s = svdup_n_u32(0xF);
+    const svint32_t vzero_sv = svdup_n_s32(0);
+    svfloat32_t acc_sum = svdup_n_f32(0);
+    svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);
+
+    switch (vector_length) {
+        case 128:
+            for (int i = 0; i < nb; ++i) {
+                const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
+                const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
+
+                const uint8_t * GGML_RESTRICT q2 = x[i].qs;
+                const int8_t  * GGML_RESTRICT q8_sv = y[i].qs;
+                const uint8_t * GGML_RESTRICT sc = x[i].scales;
+
+                svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);
+                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);
+                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);
+                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);
+
+                const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);
+                const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);
+                const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);
+                q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);
+
+                svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));
+
+                svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));
+
+                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);
+
+                svint32_t sumi1 = svdup_n_s32(0);
+
+                {
+                    const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);
+                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));
+                    svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                    const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));
+
+                    const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));
+
+
+                    const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));
+
+                    //-------------------------------
+
+                    q2 += 32;
+                    const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));
+                    const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));
+
+                    const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));
+
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));
+
+
+                    const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));
+
+
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));
+                }
+                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);
+            }
+            *s = svaddv_f32(svptrue_b32(), acc_sum);
+            break;
+
+        case 256:
+        case 512:
+            for (int i = 0; i < nb; ++i) {
+                const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
+                const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
+
+                const uint8_t * GGML_RESTRICT q2 = x[i].qs;
+                const int8_t  * GGML_RESTRICT q8_sv = y[i].qs;
+                const uint8_t * GGML_RESTRICT sc = x[i].scales;
+
+                const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;
+                const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));
+                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));
+                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);
+
+                const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);
+                const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));
+                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));
+
+                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);
+
+                svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));
+
+                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);
+
+                svint32_t sumi1 = svdup_n_s32(0);
+
+                {
+                    const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
+                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));
+                    svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2 += 32;
+
+                    const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+                }
+                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);
+            }
+            *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum);
+            break;
+
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+    }
+
+#elif __ARM_NEON
     const uint8x16_t m3 = vdupq_n_u8(0x3);
     const uint8x16_t m4 = vdupq_n_u8(0xF);
 
@@ -4602,9 +4781,9 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-        const uint8_t * restrict q2 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
-        const uint8_t * restrict sc = x[i].scales;
+        const uint8_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT sc = x[i].scales;
 
         const uint8x16_t mins_and_scales = vld1q_u8(sc);
         const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
@@ -4667,8 +4846,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-        const uint8_t * restrict q2 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
         const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
@@ -4734,8 +4913,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-        const uint8_t * restrict q2 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         // load mins and scales from block_q2_K.scales[QK_K/16]
         const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
@@ -4929,84 +5108,182 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __riscv_v_intrinsic
 
+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;
-    uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
 
-    for (int i = 0; i < nb; ++i) {
+    uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
+    uint8_t atmp[16];
 
-        const uint8_t * q2 = x[i].qs;
-        const  int8_t * q8 = y[i].qs;
-        const uint8_t * sc = x[i].scales;
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {
+            const uint8_t * q2 = x[i].qs;
+            const int8_t *  q8 = y[i].qs;
+            const uint8_t * sc = x[i].scales;
 
-        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+            const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-        size_t vl = 16;
+            size_t vl = 16;
 
-        vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
-        vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+            vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+            vuint8m1_t aux    = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
 
-        vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
+            vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
 
-        vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
-        vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
-        vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
-        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
-        vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+            vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+            vuint8mf2_t mins8    = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+            vint16m1_t  mins     = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+            vint32m2_t  prod     = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+            vint32m1_t  vsums    = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
 
-        sumf  += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
+            sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
 
-        vl = 32;
+            vl = 32;
 
-        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
-        vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+            vuint8m1_t v_b   = __riscv_vle8_v_u8m1(temp_01, vl);
 
-        uint8_t is=0;
-        int isum=0;
+            uint8_t is   = 0;
+            int     isum = 0;
 
-        for (int j = 0; j < QK_K/128; ++j) {
-            // load Q2
-            vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+            for (int j = 0; j < QK_K / 128; ++j) {
+                // load Q2
+                vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
 
-            vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
-            vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl);
-            vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl);
-            vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
+                vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+                vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
+                vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
+                vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);
 
-            // duplicate scale elements for product
-            vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl);
-            vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl);
-            vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl);
-            vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+                // duplicate scale elements for product
+                vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
+                vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
+                vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
+                vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);
 
-            vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
-            vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
-            vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
-            vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
+                vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+                vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+                vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+                vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
 
-            // load Q8
-            vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
-            vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
-            vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl);
-            vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+                // load Q8
+                vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+                vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
+                vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
+                vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);
 
-            vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
-            vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
-            vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
-            vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
+                vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+                vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+                vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+                vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
 
-            vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
-            vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
+                vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+                vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
 
-            isum += __riscv_vmv_x_s_i32m1_i32(isum1);
+                isum += __riscv_vmv_x_s_i32m1_i32(isum1);
 
-            q2+=32;  q8+=128;  is=8;
+                q2 += 32;
+                q8 += 128;
+                is = 8;
+            }
 
+            sumf += dall * isum;
         }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+            const uint8_t * q2 = x[i].qs;
+            const  int8_t * q8 = y[i].qs;
+            const uint8_t * sc = x[i].scales;
+            const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+            uint8_t *patmp = atmp;
+            int vsums;
+            int tmp;
+            __asm__ __volatile__(
+                "vsetivli zero, 16, e8, m1\n\t"
+                "vmv.v.x v8, zero\n\t"
+                "vle8.v v1, (%[sc])\n\t"
+                "vand.vi v0, v1, 0xF\n\t"
+                "vsrl.vi v1, v1, 4\n\t"
+                "vse8.v v0, (%[scale])\n\t"
+                "vsetivli zero, 16, e16, m2\n\t"
+                "vle16.v v2, (%[bsums])\n\t"
+                "vzext.vf2 v0, v1\n\t"
+                "vwmul.vv v4, v0, v2\n\t"
+                "vsetivli zero, 16, e32, m4\n\t"
+                "vredsum.vs v8, v4, v8\n\t"
+                "vmv.x.s %[vsums], v8"
+                : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
+                : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
+                : "memory"
+                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+            );
+            sumf += dmin * vsums;
+            int isum = 0;
 
-        sumf += dall * isum;
+            for (int j = 0; j < QK_K/128; ++j) {
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vle8.v v0, (%[q2])\n\t"
+                    "vsrl.vi v2, v0, 2\n\t"
+                    "vsrl.vi v4, v0, 4\n\t"
+                    "vsrl.vi v6, v0, 6\n\t"
+                    "vand.vi v0, v0, 0x3\n\t"
+                    "vand.vi v2, v2, 0x3\n\t"
+                    "vand.vi v4, v4, 0x3\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vle8.v v8, (%[q8])\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vwmul.vv v24, v4, v12\n\t"
+                    "vsetivli zero, 16, e16, m2\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vwredsum.vs v10, v16, v0\n\t"
+                    "vwredsum.vs v9, v18, v0\n\t"
+                    "vwredsum.vs v8, v20, v0\n\t"
+                    "vwredsum.vs v7, v22, v0\n\t"
+                    "vwredsum.vs v11, v24, v0\n\t"
+                    "vwredsum.vs v12, v26, v0\n\t"
+                    "vwredsum.vs v13, v28, v0\n\t"
+                    "vwredsum.vs v14, v30, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v10, v9, 1\n\t"
+                    "vslideup.vi v8, v7, 1\n\t"
+                    "vslideup.vi v11, v12, 1\n\t"
+                    "vslideup.vi v13, v14, 1\n\t"
+                    "vslideup.vi v10, v8, 2\n\t"
+                    "vslideup.vi v11, v13, 2\n\t"
+                    "vsetivli zero, 8, e32, m2\n\t"
+                    "vle8.v v15, (%[scale])\n\t"
+                    "vzext.vf4 v12, v15\n\t"
+                    "vmul.vv v10, v10, v12\n\t"
+                    "vredsum.vs v0, v10, v0\n\t"
+                    "vmv.x.s %[tmp], v0\n\t"
+                    "add %[isum], %[isum], %[tmp]"
+                    : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+                    : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
+                    , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+                q2 += 32; q8 += 128; patmp += 8;
+            }
 
+            sumf += dall * isum;
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }
 
     *s = sumf;
@@ -5061,8 +5338,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         vector signed int vsumi6 = v0;
         vector signed int vsumi7 = v0;
 
-        const uint8_t * restrict q2 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         for (int j = 0; j < QK_K/128; ++j) {
             __builtin_prefetch(q2, 0, 1);
@@ -5153,8 +5430,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-        const uint8_t * restrict q2 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
         const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
@@ -5247,7 +5524,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 #endif
 }
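The q2_K kernel above gains an SVE implementation that dispatches on the runtime vector length (svcntb()*8: a 128-bit case and a shared 256/512-bit case), and the RISC-V side now branches on __riscv_vlenb()*8 as well — 256-bit VLEN keeps the intrinsics, 128-bit drops to hand-written inline assembly. All of these compute the same super-block sum as the generic reference; roughly, per super-block i (a scalar sketch assuming the usual block_q2_K/block_q8_K layout, not part of the patch):

    const uint8_t * sc = x[i].scales;   // low nibble: scale, high nibble: min (per 16 values)
    const uint8_t * q2 = x[i].qs;       // 2-bit quants, 4 per byte
    const int8_t  * q8 = y[i].qs;

    int summs = 0;                      // min correction from the per-16 q8 sums
    for (int j = 0; j < QK_K/16; ++j) summs += y[i].bsums[j] * (sc[j] >> 4);

    int isum = 0, is = 0;
    for (int k = 0; k < QK_K/128; ++k) {
        for (int shift = 0; shift < 8; shift += 2) {      // the four 2-bit planes
            int isuml = 0;
            for (int l = 0;  l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
            isum += (sc[is++] & 0xF) * isuml;
            isuml = 0;
            for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
            isum += (sc[is++] & 0xF) * isuml;
            q8 += 32;
        }
        q2 += 32;
    }
    sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d)    * isum
          - y[i].d * GGML_FP16_TO_FP32(x[i].dmin) * summs;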
 
-void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -5258,8 +5535,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     const uint32_t kmask1 = 0x03030303;
     const uint32_t kmask2 = 0x0f0f0f0f;
 
-    const block_q3_K * restrict x = vx;
-    const block_q8_K * restrict y = vy;
+    const block_q3_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -5284,9 +5561,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict q3_sv = x[i].qs;
-        const uint8_t * restrict qh_sv = x[i].hmask;
-        const int8_t  * restrict q8_sv = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3_sv = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh_sv = x[i].hmask;
+        const int8_t  * GGML_RESTRICT q8_sv = y[i].qs;
 
         // Set up scales
         memcpy(aux, x[i].scales, 12);
@@ -5460,9 +5737,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict q3 = x[i].qs;
-        const uint8_t * restrict qh = x[i].hmask;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].hmask;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
 
@@ -5546,8 +5823,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict q3 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         // Set up scales
         memcpy(aux, x[i].scales, 12);
@@ -5651,8 +5928,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict q3 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         // Set up scales
         aux = (const uint32_t *)x[i].scales;
@@ -5785,9 +6062,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
-        const uint8_t * restrict q3 = x[i].qs;
-        const uint8_t * restrict hm = x[i].hmask;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         // Process blocks with SIMD
         int8_t * a = aux8;
@@ -5871,97 +6148,221 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     uint32_t aux[3];
     uint32_t utmp[4];
 
+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
 
-        const uint8_t * restrict q3 = x[i].qs;
-        const uint8_t * restrict qh = x[i].hmask;
-        const  int8_t * restrict q8 = y[i].qs;
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {
 
-        memcpy(aux, x[i].scales, 12);
-        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
-        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
-        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
-        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+            const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+            const uint8_t * GGML_RESTRICT qh = x[i].hmask;
+            const  int8_t * GGML_RESTRICT q8 = y[i].qs;
 
-        int8_t * scale = (int8_t *)utmp;
-        for (int j = 0; j < 16; ++j) scale[j] -= 32;
+            memcpy(aux, x[i].scales, 12);
+            utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+            utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+            utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+            utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+            int8_t * scale = (int8_t *)utmp;
+            for (int j = 0; j < 16; ++j) scale[j] -= 32;
 
 
-        size_t vl = 32;
-        uint8_t m =  1;
+            size_t vl = 32;
+            uint8_t m =  1;
 
-        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
-        vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
+            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+            vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
 
-        int sum_t = 0;
+            int sum_t = 0;
 
-        for (int j = 0; j < QK_K; j += 128) {
+            for (int j = 0; j < QK_K; j += 128) {
 
-            vl = 32;
+                vl = 32;
 
-            // load Q3
-            vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+                // load Q3
+                vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
 
-            vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
-            vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
-            vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
-            vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
+                vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+                vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
+                vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
+                vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
 
-            // compute mask for subtraction
-            vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
-            vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
-            vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
-            m <<= 1;
+                // compute mask for subtraction
+                vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+                vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
+                m <<= 1;
 
-            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
-            vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
-            vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
-            m <<= 1;
+                vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+                vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
+                m <<= 1;
 
-            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
-            vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
-            vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
-            m <<= 1;
+                vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+                vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
+                m <<= 1;
 
-            vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
-            vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
-            vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
-            m <<= 1;
+                vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+                vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
+                m <<= 1;
 
-            // load Q8 and take product with Q3
-            vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
-            vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
-            vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
-            vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+                // load Q8 and take product with Q3
+                vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+                vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+                vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+                vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
 
-            vl = 16;
+                vl = 16;
 
-            // retrieve lane to multiply with scale
-            vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
-            vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
-            vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
-            vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
-            vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
-            vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
-            vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
-            vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+                // retrieve lane to multiply with scale
+                vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+                vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+                vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+                vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+                vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+                vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+                vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+                vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
 
-            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
-            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
-            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
-            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
+                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
 
-            sum_t +=  __riscv_vmv_x_s_i32m1_i32(isum3);
+                sum_t +=  __riscv_vmv_x_s_i32m1_i32(isum3);
 
-            q3 += 32;    q8 += 128;   scale += 8;
+                q3 += 32;    q8 += 128;   scale += 8;
+
+            }
+
+            const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+            sumf += d*sum_t;
 
         }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+            const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+            const uint8_t * GGML_RESTRICT qh = x[i].hmask;
+            const  int8_t * GGML_RESTRICT q8 = y[i].qs;
 
-        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+            int8_t * scale = (int8_t *)utmp;
+            int tmp;
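+            // unpack the 12 packed scale bytes into 16 signed 6-bit scales (biased by -32), the vector equivalent of the scalar utmp/kmask computation in the 256-bit path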
+            __asm__ __volatile__(
+                "vsetivli zero, 12, e8, m1\n\t"
+                "vle8.v v0, (%[s6b])\n\t"
+                "vmv1r.v v2, v0\n\t"
+                "vsetivli zero, 2, e64, m1\n\t"
+                "vmv.v.x v9, %[sh]\n\t"
+                "vslidedown.vi v1, v0, 1\n\t"
+                "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
+                "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
+                "vsetivli zero, 4, e32, m1\n\t"
+                "vid.v v9\n\t"
+                "vmv.x.s %[tmp], v1\n\t"
+                "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
+                "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
+                "vsrl.vv v4, v1, v9\n\t"
+                "vsrl.vv v2, v0, v8\n\t"
+                "vand.vx v5, v4, %[kmask1]\n\t"
+                "vand.vx v3, v2, %[kmask2]\n\t"
+                "vsll.vi v6, v5, 4\n\t"
+                "vor.vv v7, v6, v3\n\t"
+                "vsetivli zero, 16, e8, m1\n\t"
+                "vsub.vx v0, v7, %[c]\n\t"
+                "vse8.v v0, (%[scale])"
+                : [tmp] "=&r" (tmp)
+                : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
+                , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
+                : "memory"
+                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+            );
 
-        sumf += d*sum_t;
+            uint8_t m = 1;
+            int isum = 0;
+            for (int j = 0; j < QK_K; j += 128) {
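+                // extract the four 2-bit planes of q3, subtract 4 where the matching hmask bit is clear, multiply with q8 and accumulate each 16-value group weighted by its scale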
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
+                    "vle8.v v8, (%[q3])\n\t"
+                    "vsrl.vi v10, v8, 2\n\t"
+                    "vsrl.vi v12, v8, 4\n\t"
+                    "vsrl.vi v14, v8, 6\n\t"
+                    "vand.vi v8, v8, 3\n\t"
+                    "vand.vi v10, v10, 3\n\t"
+                    "vand.vi v12, v12, 3\n\t"
+                    "vle8.v v2, (%[qh])\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v8, v8, -4, v0.t\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v10, v10, -4, v0.t\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v12, v12, -4, v0.t\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v14, v14, -4, v0.t\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vle8.v v0, (%[q8])\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vwmul.vv v24, v4, v12\n\t"
+                    "vsetivli zero, 16, e16, m2\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vwredsum.vs v10, v16, v0\n\t"
+                    "vwredsum.vs v9, v18, v0\n\t"
+                    "vwredsum.vs v8, v20, v0\n\t"
+                    "vwredsum.vs v7, v22, v0\n\t"
+                    "vwredsum.vs v11, v24, v0\n\t"
+                    "vwredsum.vs v12, v26, v0\n\t"
+                    "vwredsum.vs v13, v28, v0\n\t"
+                    "vwredsum.vs v14, v30, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v10, v9, 1\n\t"
+                    "vslideup.vi v8, v7, 1\n\t"
+                    "vslideup.vi v11, v12, 1\n\t"
+                    "vslideup.vi v13, v14, 1\n\t"
+                    "vslideup.vi v10, v8, 2\n\t"
+                    "vslideup.vi v11, v13, 2\n\t"
+                    "vsetivli zero, 8, e32, m2\n\t"
+                    "vle8.v v15, (%[scale])\n\t"
+                    "vsext.vf4 v12, v15\n\t"
+                    "vmul.vv v10, v10, v12\n\t"
+                    "vredsum.vs v0, v10, v0\n\t"
+                    "vmv.x.s %[tmp], v0\n\t"
+                    "add %[isum], %[isum], %[tmp]"
+                    : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
+                    : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
+                    , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+                q3 += 32;    q8 += 128;   scale += 8;
+            }
 
+            const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+            sumf += d * isum;
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }
 
     *s = sumf;
@@ -6016,8 +6417,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         vector signed int vsumi6 = v0;
         vector signed int vsumi7 = v0;
 
-        const uint8_t * restrict q3 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         for (int j = 0; j < QK_K/128; ++j) {
             __builtin_prefetch(q3, 0, 1);
@@ -6130,8 +6531,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     for (int i = 0; i < nb; ++i) {
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-        const uint8_t * restrict q3 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
         // Set up scales
         memcpy(aux, x[i].scales, 12);
         __m128i scales128 = lsx_set_w(
@@ -6216,11 +6617,11 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
-        const uint8_t * restrict q3 = x[i].qs;
-        const uint8_t * restrict hm = x[i].hmask;
-        const  int8_t * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
+        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
         memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * restrict a = aux8;
+        int8_t * GGML_RESTRICT a = aux8;
         uint8_t m = 1;
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
@@ -6263,7 +6664,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 }
 
-void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -6271,8 +6672,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     UNUSED(by);
     UNUSED(bs);
 
-    const block_q4_K * restrict x = vx;
-    const block_q8_K * restrict y = vy;
+    const block_q4_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -6307,8 +6708,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const uint8_t * scales = (const uint8_t *)utmp;
 
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         const int vector_length = ggml_cpu_get_sve_cnt()*8;
         const svuint8_t m4b = svdup_n_u8(0xf);
@@ -6395,8 +6796,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const uint8_t * scales = (const uint8_t *)utmp;
 
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         int32_t sumi1 = 0;
         int32_t sumi2 = 0;
@@ -6434,8 +6835,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign
 
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         // Process scales and mins
         memcpy(utmp, x[i].scales, 12);
@@ -6447,7 +6848,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         // Sum mins * q8sums
         int32_t sumi = 0;
-        const int16_t * restrict q8sums = y[i].bsums;
+        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
         const uint8_t * m = (const uint8_t *)&utmp[2];
         for (int j = 0; j < 16; j += 2) {
             sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
@@ -6546,8 +6947,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
 
@@ -6605,8 +7006,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         memcpy(utmp, x[i].scales, 12);
         utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -6679,69 +7080,181 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     const uint8_t * scales = (const uint8_t*)&utmp[0];
     const uint8_t * mins   = (const uint8_t*)&utmp[2];
 
+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;
 
-    for (int i = 0; i < nb; ++i) {
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {
 
-        size_t vl = 8;
+            size_t vl = 8;
 
-        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+            const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
-        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
-        vint16mf2_t q8sums   = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+            vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+            vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+            vint16mf2_t q8sums   = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
 
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
+            memcpy(utmp, x[i].scales, 12);
+            utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+            const uint32_t uaux = utmp[1] & kmask1;
+            utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+            utmp[2] = uaux;
+            utmp[0] &= kmask1;
 
-        vuint8mf4_t mins8  = __riscv_vle8_v_u8mf4(mins, vl);
-        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
-        vint32m1_t  prod   = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+            vuint8mf4_t mins8  = __riscv_vle8_v_u8mf4(mins, vl);
+            vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+            vint32m1_t  prod   = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
 
-        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
-        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+            vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+            sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
 
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+            const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+            const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
-        vl = 32;
+            vl = 32;
 
-        int32_t sum_1 = 0;
-        int32_t sum_2 = 0;
+            int32_t sum_1 = 0;
+            int32_t sum_2 = 0;
 
-        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+            vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
 
-        for (int j = 0; j < QK_K/64; ++j) {
-            // load Q4
-            vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+            for (int j = 0; j < QK_K/64; ++j) {
+                // load Q4
+                vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
 
-            // load Q8 and multiply it with lower Q4 nibble
-            vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);
-            vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
-            vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
-            vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
+                // load Q8 and multiply it with lower Q4 nibble
+                vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+                vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+                vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+                vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
 
-            sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
+                sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
 
-            // load Q8 and multiply it with upper Q4 nibble
-            vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
-            vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
-            vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
-            vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
+                // load Q8 and multiply it with upper Q4 nibble
+                vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+                vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+                vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+                vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
 
-            sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
+                sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
 
-            q4 += 32;    q8 += 64;
+                q4 += 32;    q8 += 64;
+
+            }
+
+            sumf += d*(sum_1 + sum_2);
 
         }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+            const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-        sumf += d*(sum_1 + sum_2);
+            int tmp, tmp2, sumi;
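+            // unpack scales and mins from the 12 packed bytes into utmp and reduce (bsums[2j] + bsums[2j+1]) * mins[j] into sumi; the dmin correction is applied just below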
+            __asm__ __volatile__(
+                "vsetivli zero, 12, e8, m1\n\t"
+                "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
+                "vsetivli zero, 4, e32, m1\n\t"
+                "vslidedown.vi v2, v1, 2\n\t"
+                "vmv1r.v v3, v2\n\t"
+                "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
+                "vsetivli zero, 2, e32, m1\n\t"
+                "vmv.v.i v4, 4\n\t"
+                "vand.vx v8, v1, %[kmask1]\n\t"
+                "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
+                "vsrl.vi v6, v1, 6\n\t"
+                "vsrl.vv v7, v2, v5\n\t"
+                "vand.vx v0, v6, %[kmask3]\n\t"
+                "vand.vx v2, v7, %[kmask2]\n\t"
+                "vsll.vi v6, v0, 4\n\t"
+                "li %[t2], 8\n\t"
+                "addi %[t1], %[utmp], 4\n\t"
+                "vor.vv v1, v6, v2\n\t"
+                "vsse32.v v8, (%[utmp]), %[t2]\n\t"
+                "vsse32.v v1, (%[t1]), %[t2]\n\t"
+                "vsetivli zero, 8, e16, m1\n\t"
+                "vle32.v v2, (%[bsums])\n\t"
+                "vnsrl.wi v0, v2, 0\n\t"
+                "vnsrl.wi v1, v2, 16\n\t"
+                "vadd.vv v2, v0, v1\n\t"
+                "vle8.v v3, (%[mins])\n\t"
+                "vzext.vf2 v4, v3\n\t"
+                "vwmul.vv v6, v4, v2\n\t"
+                "vmv.v.x v0, zero\n\t"
+                "vsetivli zero, 8, e32, m2\n\t"
+                "vredsum.vs v0, v6, v0\n\t"
+                "vmv.x.s %[sumi], v0"
+                : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
+                : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
+                , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
+                , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
+                : "memory"
+                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+            );
+            sumf -= dmin * sumi;
 
+            const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+            const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+            sumi = 0;
+            const uint8_t * scale = scales;
+
+            for (int j = 0; j < QK_K/128; ++j) {
+                int vl128 = 128, vl64 = 64, vl32 = 32;
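+                // split q4 into low/high nibbles, multiply each 32-byte group with the matching q8 bytes, then reduce the four partial sums weighted by their scales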
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vle8.v v8, (%[q8])\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vle8.v v0, (%[q4])\n\t"
+                    "vsrl.vi v4, v0, 4\n\t"
+                    "vand.vi v0, v0, 0xF\n\t"
+                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vwmul.vv v28, v6, v14\n\t"
+                    "vwmul.vv v20, v4, v10\n\t"
+                    "vwmul.vv v24, v2, v12\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vle8.v v2, (%[scale])\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vzext.vf4 v1, v2\n\t"
+                    "vsetvli zero, %[vl32], e16, m4\n\t"
+                    "vwredsum.vs v6, v24, v0\n\t"
+                    "vwredsum.vs v7, v28, v0\n\t"
+                    "vwredsum.vs v4, v16, v0\n\t"
+                    "vwredsum.vs v5, v20, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v6, v7, 1\n\t"
+                    "vslideup.vi v4, v5, 1\n\t"
+                    "vslideup.vi v4, v6, 2\n\t"
+                    "vmul.vv v8, v4, v1\n\t"
+                    "vredsum.vs v0, v8, v0\n\t"
+                    "vmv.x.s %[tmp], v0\n\t"
+                    "add %[sumi], %[sumi], %[tmp]"
+                    : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
+                    : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
+                    , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+
+                q4 += 64;    q8 += 128;    scale += 4;
+            }
+
+            sumf += d * sumi;
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }
 
     *s = sumf;
@@ -6808,8 +7321,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         vector signed int vsumi2 = v0;
         vector signed int vsumi3 = v0;
 
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         for (int j = 0; j < QK_K/64; j+=2) {
             __builtin_prefetch(q4, 0, 1);
@@ -6900,8 +7413,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
         const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
@@ -6983,8 +7496,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
 
         const uint8_t * scales = (const uint8_t *)utmp;
-        const uint8_t * restrict x0 = x[i].qs;
-        const int8_t  * restrict y0 = y[i].qs;
+        const uint8_t * GGML_RESTRICT x0 = x[i].qs;
+        const int8_t  * GGML_RESTRICT y0 = y[i].qs;
 
         int32_t sumi1 = 0;
         int32_t sumi2 = 0;
@@ -7032,10 +7545,10 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
-        const uint8_t * restrict q4 = x[i].qs;
-        const  int8_t * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
         memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * restrict a = aux8;
+        int8_t * GGML_RESTRICT a = aux8;
         for (int j = 0; j < QK_K/64; ++j) {
             for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
             a += 32;
@@ -7078,7 +7591,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 #endif
 }
 
-void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy,  size_t by, int nrc) {
+void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -7086,8 +7599,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     UNUSED(by);
     UNUSED(bs);
 
-    const block_q5_K * restrict x = vx;
-    const block_q8_K * restrict y = vy;
+    const block_q5_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -7129,9 +7642,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const uint8_t * scales = (const uint8_t *)utmp;
 
-        const uint8_t * restrict q5 = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
 
@@ -7176,8 +7689,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     float summs = 0.f;
 
     for (int i = 0; i < nb; ++i) {
-        const uint8_t * restrict q5 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
@@ -7260,8 +7773,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-        const uint8_t * restrict q5 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         memcpy(utmp, x[i].scales, 12);
         utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -7352,9 +7865,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign
 
-        const uint8_t * restrict q5 = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         // Process scales and mins
         memcpy(utmp, x[i].scales, 12);
@@ -7366,7 +7879,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         // Sum mins * q8sums
         int32_t sumi_mins = 0;
-        const int16_t * restrict q8sums = y[i].bsums;
+        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
         const uint8_t * m = (const uint8_t *)&utmp[2];
         for (int j = 0; j < 16; j += 2) {
             sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
@@ -7470,16 +7983,16 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         vl = 8;
 
-        const uint8_t * restrict q5 = x[i].qs;
-        const uint8_t * restrict hm = x[i].qh;
-        const  int8_t * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
+        const uint8_t * GGML_RESTRICT hm = x[i].qh;
+        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
 
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
         const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
 
-        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
-        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
-        vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+        vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);
+        vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);
+        vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);
 
         memcpy(utmp, x[i].scales, 12);
         utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -7488,11 +8001,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-        vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
-        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
-        vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+        vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);
+        vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);
 
-        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
         sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
 
         vl = 32;
@@ -7501,43 +8014,42 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         uint8_t m = 1;
         vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
-        vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
+        vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);
 
         for (int j = 0; j < QK_K/64; ++j) {
             // load Q5 and Q8
-            vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
-            vint8m1_t  q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
-            vint8m1_t  q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);
+            vint8m2_t  q8_y1 = __riscv_vle8_v_i8m2(q8, vl);
+            vint8m2_t  q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);
 
             // compute mask for addition
-            vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
-            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
-            vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
-            vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl);
+            vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));
+            vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);
+            vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);
+            vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);
             m <<= 1;
 
-            vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
-            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
-            vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
-            vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl);
+            vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));
+            vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);
+            vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);
+            vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);
             m <<= 1;
 
-            vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
-            vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
+            vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);
+            vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);
 
-            vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
-            vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
+            vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);
+            vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);
 
-            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
-            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
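+            // seed the second reduction with vacc1 so both sub-block sums land in vacc2 and a single scalar extraction suffices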
+            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);
+            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);
 
-            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
+            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);
             q5 += 32;    q8 += 64;
 
         }
 
-        vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
-        sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+        sums += aux32 * d;
 
     }
 
@@ -7611,8 +8123,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         vector signed int vsumi2 = v0;
         vector signed int vsumi3 = v0;
 
-        const uint8_t * restrict q5 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         for (int j = 0; j < QK_K/64; ++j) {
             __builtin_prefetch(q5, 0, 1);
@@ -7684,8 +8196,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     for (int i = 0; i < nb; ++i) {
 
-        const uint8_t * restrict q5 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
@@ -7794,9 +8306,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
 
         const uint8_t * scales = (const uint8_t *)utmp;
-        const uint8_t * restrict x0l = x[i].qs;
-        const uint8_t * restrict x0h = x[i].qh;
-        const int8_t  * restrict y0 = y[i].qs;
+        const uint8_t * GGML_RESTRICT x0l = x[i].qs;
+        const uint8_t * GGML_RESTRICT x0h = x[i].qh;
+        const int8_t  * GGML_RESTRICT y0 = y[i].qs;
 
         v_xh[0] = vec_xl(0 , x0h);
         v_xh[1] = vec_xl(16, x0h);
@@ -7849,11 +8361,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
-        const uint8_t * restrict q4 = x[i].qs;
-        const uint8_t * restrict hm = x[i].qh;
-        const  int8_t * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const uint8_t * GGML_RESTRICT hm = x[i].qh;
+        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
         memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * restrict a = aux8;
+        int8_t * GGML_RESTRICT a = aux8;
         uint8_t m = 1;
         for (int j = 0; j < QK_K/64; ++j) {
             for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
@@ -7900,7 +8412,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 #endif
 }
 
-void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -7908,12 +8420,161 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     UNUSED(by);
     UNUSED(bs);
 
-    const block_q6_K * restrict x = vx;
-    const block_q8_K * restrict y = vy;
+    const block_q6_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+    float sum = 0;
+    svuint8_t m4b = svdup_n_u8(0xf);
+    svint32_t vzero = svdup_n_s32(0);
+    svuint8_t mone = svdup_n_u8(0x30);
+    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        const int8_t * GGML_RESTRICT scale = x[i].scales;
+
+        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+        const svint64_t prod = svdup_n_s64(0);
+        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));
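+        // isum_mins = sum(bsums * scales), used below as (isum - 32 * isum_mins) to fold in the constant -32 bias of the 6-bit quants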
+        int32_t isum = 0;
+
+        switch (vector_length) {
+            case 128:
+                {
+                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
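+                        // rebuild the 6-bit values: low nibble from ql, high two bits from qh shifted into the 0x30 positions, then dot with q8 and scale each 16-byte group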
+                        svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+                        scale += 4;
+                        q8bytes_1 = svld1_s8(pg8_16, q8);
+                        q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+                        scale += 4;
+                    }
+                    isum += svaddv_s32(pg32_4, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            case 256:
+            case 512:
+                {
+                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; j++) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+                        q8 += 128;
+                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
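+                        // broadcast each int8 scale across four int32 lanes (two zips + widen) so it lines up with the 4-int8 groups reduced by svdot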
+                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+                        scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+                        scale += 8;
+                    }
+                    isum += svaddv_s32(pg32_8, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+
+    *s = sum;
+
+#elif __ARM_NEON
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
@@ -7929,11 +8590,11 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const float d_all = GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict q6 = x[i].ql;
-        const uint8_t * restrict qh = x[i].qh;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
-        const int8_t * restrict scale = x[i].scales;
+        const int8_t * GGML_RESTRICT scale = x[i].scales;
 
         const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
         const int8x16_t scales = vld1q_s8(scale);
@@ -8020,9 +8681,9 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict q4 = x[i].ql;
-        const uint8_t * restrict qh = x[i].qh;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
 
@@ -8098,9 +8759,9 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict q4 = x[i].ql;
-        const uint8_t * restrict qh = x[i].qh;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         // handle the q6_k -32 offset separately using bsums
         const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
@@ -8199,8 +8860,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     for (int i = 0; i < nb; ++i) {
         // Unpack 6-bit quantized data into aux8 (unchanged)
-        const uint8_t * restrict q4 = x[i].ql;
-        const uint8_t * restrict qh = x[i].qh;
+        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
         int8_t * a = aux8;
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) {
@@ -8214,8 +8875,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             qh += 32;
         }
 
-        const int8_t * restrict a_ptr = aux8;
-        const int8_t * restrict q8 = y[i].qs;
+        const int8_t * GGML_RESTRICT a_ptr = aux8;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
         v128_t acc0 = wasm_i32x4_splat(0);
         v128_t acc1 = wasm_i32x4_splat(0);
 
@@ -8273,85 +8934,168 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __riscv_v_intrinsic
 
+    const int vector_length = __riscv_vlenb() * 8;
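+    // __riscv_vlenb() reports VLEN in bytes; convert to bits and dispatch to a kernel tuned for that width.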
     float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
 
-        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {
 
-        const uint8_t * restrict q6 = x[i].ql;
-        const uint8_t * restrict qh = x[i].qh;
-        const  int8_t * restrict q8 = y[i].qs;
+            const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
 
-        const int8_t * restrict scale = x[i].scales;
+            const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+            const uint8_t * GGML_RESTRICT qh = x[i].qh;
+            const  int8_t * GGML_RESTRICT q8 = y[i].qs;
 
-        size_t vl;
+            const int8_t * GGML_RESTRICT scale = x[i].scales;
 
-        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+            size_t vl;
 
-        int sum_t = 0;
-        int is = 0;
+            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
 
-        for (int j = 0; j < QK_K/128; ++j) {
+            int sum_t = 0;
+            int is = 0;
 
-            vl = 32;
+            for (int j = 0; j < QK_K/128; ++j) {
 
-            // load qh
-            vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+                vl = 32;
 
-            // load Q6
-            vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
-            vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+                // load qh
+                vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
 
-            vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
-            vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
-            vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
-            vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+                // load Q6
+                vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+                vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
 
-            vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
-            vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
-            vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
-            vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+                vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+                vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+                vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+                vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
 
-            vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
-            vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
-            vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
-            vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+                vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+                vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+                vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+                vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
 
-            vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
-            vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
-            vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
-            vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+                vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+                vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+                vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+                vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
 
-            // load Q8 and take product
-            vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
-            vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
-            vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
-            vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+                vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+                vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+                vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+                vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
 
-            vl = 16;
+                // load Q8 and take product
+                vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+                vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+                vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+                vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
 
-            vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
-            vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
-            vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
-            vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
-            vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
-            vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
-            vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
-            vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+                vl = 16;
 
-            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
-            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
-            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
-            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+                vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+                vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+                vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+                vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+                vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+                vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+                vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+                vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
 
-            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
 
-            q6 += 64;   qh += 32;   q8 += 128;   is=8;
+                sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+                q6 += 64;   qh += 32;   q8 += 128;   is=8;
+
+            }
+
+            sumf += d * sum_t;
 
         }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
 
-        sumf += d * sum_t;
+            const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
 
+            const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+            const uint8_t * GGML_RESTRICT qh = x[i].qh;
+            const  int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+            const int8_t * GGML_RESTRICT scale = x[i].scales;
+
+            int sum_t = 0;
+            int t0;
+
+            for (int j = 0; j < QK_K/128; ++j) {
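+                // One 128-value sub-block in inline RVV asm: rebuild the signed 6-bit values (subtract 32),
+                // widen-multiply against q8, reduce each 16-element group, then weight the eight partial sums by the scales.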
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vle8.v v4, (%[qh])\n\t"
+                    "vsll.vi v0, v4, 4\n\t"
+                    "vsll.vi v2, v4, 2\n\t"
+                    "vsrl.vi v6, v4, 2\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vle8.v v8, (%[q6])\n\t"
+                    "vsrl.vi v12, v8, 4\n\t"
+                    "vand.vi v8, v8, 0xF\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vand.vx v0, v0, %[mask]\n\t"
+                    "vor.vv v8, v8, v0\n\t"
+                    "vle8.v v0, (%[q8])\n\t"
+                    "vsub.vx v8, v8, %[vl32]\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vwmul.vv v24, v4, v12\n\t"
+                    "vsetivli zero, 16, e16, m2\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vwredsum.vs v10, v16, v0\n\t"
+                    "vwredsum.vs v9, v18, v0\n\t"
+                    "vwredsum.vs v8, v20, v0\n\t"
+                    "vwredsum.vs v7, v22, v0\n\t"
+                    "vwredsum.vs v11, v24, v0\n\t"
+                    "vwredsum.vs v12, v26, v0\n\t"
+                    "vwredsum.vs v13, v28, v0\n\t"
+                    "vwredsum.vs v14, v30, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v10, v9, 1\n\t"
+                    "vslideup.vi v8, v7, 1\n\t"
+                    "vslideup.vi v11, v12, 1\n\t"
+                    "vslideup.vi v13, v14, 1\n\t"
+                    "vslideup.vi v10, v8, 2\n\t"
+                    "vslideup.vi v11, v13, 2\n\t"
+                    "vsetivli zero, 8, e32, m2\n\t"
+                    "vle8.v v2, (%[scale])\n\t"
+                    "vsext.vf4 v4, v2\n\t"
+                    "vmul.vv v2, v4, v10\n\t"
+                    "vredsum.vs v0, v2, v0\n\t"
+                    "vmv.x.s %[t0], v0\n\t"
+                    "add %[sumi], %[sumi], %[t0]"
+                    : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
+                    : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
+                    , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+                    , [mask] "r" (0x30)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+                q6 += 64;   qh += 32;   q8 += 128;   scale += 8;
+            }
+
+            sumf += d * sum_t;
+
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }
 
     *s = sumf;
@@ -8384,10 +9128,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         vector signed int vsumi6 = v0;
         vector signed int vsumi7 = v0;
 
-        const uint8_t * restrict q6 = x[i].ql;
-        const uint8_t * restrict qh = x[i].qh;
-        const int8_t  * restrict qs = x[i].scales;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT qs = x[i].scales;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         for (int j = 0; j < QK_K/128; ++j) {
             __builtin_prefetch(q6, 0, 0);
@@ -8503,9 +9247,9 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict q4 = x[i].ql;
-        const uint8_t * restrict qh = x[i].qh;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
         const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
@@ -8571,11 +9315,11 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     for (int i = 0; i < nb; ++i) {
         const float d_all = GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict x0l = x[i].ql;
-        const uint8_t * restrict x0h = x[i].qh;
-        const int8_t  * restrict y0 = y[i].qs;
+        const uint8_t * GGML_RESTRICT x0l = x[i].ql;
+        const uint8_t * GGML_RESTRICT x0h = x[i].qh;
+        const int8_t  * GGML_RESTRICT y0 = y[i].qs;
 
-        const int8_t  * restrict scale = x[i].scales;
+        const int8_t  * GGML_RESTRICT scale = x[i].scales;
 
         const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
         const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
@@ -8686,11 +9430,11 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
-        const uint8_t * restrict q4 = x[i].ql;
-        const uint8_t * restrict qh = x[i].qh;
-        const  int8_t * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
         memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * restrict a = aux8;
+        int8_t * GGML_RESTRICT a = aux8;
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) {
                 a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
@@ -8758,7 +9502,7 @@ static const int8_t keven_signs_q2xs[1024] = {
 };
 #endif
 
-void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -8766,8 +9510,8 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     UNUSED(by);
     UNUSED(bs);
 
-    const block_iq2_xxs * restrict x = vx;
-    const block_q8_K    * restrict y = vy;
+    const block_iq2_xxs * GGML_RESTRICT x = vx;
+    const block_q8_K    * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -8785,8 +9529,8 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * restrict q2 = x[i].qs;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
         float sumf1 = 0, sumf2 = 0;
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
             q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
@@ -8822,8 +9566,8 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * restrict q2 = x[i].qs;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
@@ -8863,8 +9607,8 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * restrict q2 = x[i].qs;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
         __m128i sumi1_0 = _mm_setzero_si128();
         __m128i sumi1_1 = _mm_setzero_si128();
         __m128i sumi2_0 = _mm_setzero_si128();
@@ -8928,8 +9672,8 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
         vector signed int vsumi2 = v0;
         vector signed int vsumi3 = v0;
 
-        const uint16_t * restrict q2 = x[i].qs;
-        const int8_t  *  restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;
 
         for (int j = 0; j < QK_K/32; j += 2) {
             __builtin_prefetch(q2, 0, 1);
@@ -9005,8 +9749,8 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     __m256 accumf = (__m256)__lasx_xvldi(0);
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * restrict q2 = x[i].qs;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
         __m256i sumi1 = __lasx_xvldi(0);
         __m256i sumi2 = __lasx_xvldi(0);
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
@@ -9046,8 +9790,8 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
 //
 //    for (int i = 0; i < nb; ++i) {
 //        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-//        const uint16_t * restrict q2 = x[i].qs;
-//        const int8_t   * restrict q8 = y[i].qs;
+//        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+//        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
 //
 //        float sumf1 = 0, sumf2 = 0;
 //
@@ -9095,8 +9839,8 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     float sumf = 0.f;
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * restrict q2 = x[i].qs;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
         int32_t bsum = 0;
         for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
             memcpy(aux32, q2, 2*sizeof(uint32_t));
@@ -9119,7 +9863,7 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
 #endif
 }
 
-void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -9127,8 +9871,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     UNUSED(by);
     UNUSED(bs);
 
-    const block_iq2_xs * restrict x = vx;
-    const block_q8_K   * restrict y = vy;
+    const block_iq2_xs * GGML_RESTRICT x = vx;
+    const block_q8_K   * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -9145,8 +9889,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * restrict q2 = x[i].qs;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
         const uint8x8_t scales8 = vld1_u8(x[i].scales);
         const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
         const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
@@ -9223,8 +9967,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * restrict q2 = x[i].qs;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
 
         memcpy(&aux64, x[i].scales, 8);
         __m128i stmp = _mm_set1_epi64x(aux64);
@@ -9344,8 +10088,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * restrict q2 = x[i].qs;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
 
         memcpy(&aux64, x[i].scales, 8);
         __m128i stmp = _mm_set1_epi64x(aux64);
@@ -9499,8 +10243,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     __m256 accumf = (__m256)__lasx_xvldi(0);
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * restrict q2 = x[i].qs;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
 
         memcpy(&aux64, x[i].scales, 8);
         __m128i stmp = __lsx_vreplgr2vr_d(aux64);
@@ -9597,9 +10341,9 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
         vector signed int vsumi2 = v0;
         vector signed int vsumi3 = v0;
 
-        const uint16_t * restrict q2 = x[i].qs;
-        const uint8_t  * restrict sc = x[i].scales;
-        const int8_t  *  restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const uint8_t  * GGML_RESTRICT sc = x[i].scales;
+        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;
 
         for (int j = 0; j < QK_K/64; ++j) {
             __builtin_prefetch(q2, 0, 1);
@@ -9669,9 +10413,9 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     float sumf = 0.f;
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * restrict q2 = x[i].qs;
-        const uint8_t  * restrict sc = x[i].scales;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const uint8_t  * GGML_RESTRICT sc = x[i].scales;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
         int32_t bsum = 0;
         for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
             const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
@@ -9704,7 +10448,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 #endif
 }
 
-void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -9712,8 +10456,8 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
     UNUSED(by);
     UNUSED(bs);
 
-    const block_iq2_s * restrict x = vx;
-    const block_q8_K  * restrict y = vy;
+    const block_iq2_s * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -9739,10 +10483,10 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
 
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
 
-        const uint8_t * restrict qs = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         int sumi1 = 0, sumi2 = 0;
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
@@ -9813,10 +10557,10 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict qs = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         memcpy(&aux64, x[i].scales, 8);
         const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
@@ -9886,10 +10630,10 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict qs = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         memcpy(&aux64, x[i].scales, 8);
         const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
@@ -9984,11 +10728,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
         vector signed int vsumi2 = v0;
         vector signed int vsumi3 = v0;
 
-        const uint8_t *  restrict q2 = x[i].qs;
-        const uint8_t *  restrict qh = x[i].qh;
-        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
-        const uint8_t *  restrict sc = x[i].scales;
-        const int8_t  *  restrict q8 = y[i].qs;
+        const uint8_t *  GGML_RESTRICT q2 = x[i].qs;
+        const uint8_t *  GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const uint8_t *  GGML_RESTRICT sc = x[i].scales;
+        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;
 
         for (int j = 0; j < QK_K/32; j += 2) {
             __builtin_prefetch(q2, 0, 1);
@@ -10085,10 +10829,10 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
     __m256 accumf = (__m256)__lasx_xvldi(0);
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict qs = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
         __m128i tmp1;
         memcpy(&aux64, x[i].scales, 8);
@@ -10182,7 +10926,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
 
 }
 
-void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -10190,8 +10934,8 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     UNUSED(by);
     UNUSED(bs);
 
-    const block_iq3_xxs * restrict x = vx;
-    const block_q8_K    * restrict y = vy;
+    const block_iq3_xxs * GGML_RESTRICT x = vx;
+    const block_q8_K    * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -10207,9 +10951,9 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict q3 = x[i].qs;
-        const uint8_t * restrict gas = x[i].qs + QK_K/4;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
         float sumf1 = 0, sumf2 = 0;
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
             q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
@@ -10245,9 +10989,9 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict q3 = x[i].qs;
-        const uint8_t * restrict gas = x[i].qs + QK_K/4;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
@@ -10290,9 +11034,9 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict q3 = x[i].qs;
-        const uint8_t * restrict gas = x[i].qs + QK_K/4;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
         __m128i sumi1_0 = _mm_setzero_si128();
         __m128i sumi1_1 = _mm_setzero_si128();
         __m128i sumi2_0 = _mm_setzero_si128();
@@ -10359,9 +11103,9 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
         vector signed int vsumi2 = v0;
         vector signed int vsumi3 = v0;
 
-        const uint8_t * restrict q3 = x[i].qs;
-        const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4);
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const uint32_t * GGML_RESTRICT signs = (const uint32_t *)(x[i].qs + QK_K/4);
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
 #pragma GCC unroll 1
         for (int j = 0; j < QK_K/32; j += 2) {
@@ -10433,9 +11177,9 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     __m256 accumf = (__m256)__lasx_xvldi(0);
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict q3 = x[i].qs;
-        const uint8_t * restrict gas = x[i].qs + QK_K/4;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
         __m256i sumi1 = __lasx_xvldi(0);
         __m256i sumi2 = __lasx_xvldi(0);
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
@@ -10478,9 +11222,9 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     float sumf = 0.f;
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict q3 = x[i].qs;
-        const uint8_t * restrict gas = x[i].qs + QK_K/4;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
         int32_t bsum = 0;
         for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
             memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
@@ -10505,7 +11249,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
 #endif
 }
 
-void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -10513,8 +11257,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
     UNUSED(by);
     UNUSED(bs);
 
-    const block_iq3_s * restrict x = vx;
-    const block_q8_K  * restrict y = vy;
+    const block_iq3_s * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -10551,10 +11295,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict qs = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
-        const int8_t   * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
 
         memcpy(scales32, x[i].scales, 4);
         scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
@@ -10633,10 +11377,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict qs = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
@@ -10718,10 +11462,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict qs = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
         __m128i sumi1_0 = _mm_setzero_si128();
         __m128i sumi1_1 = _mm_setzero_si128();
         __m128i sumi2_0 = _mm_setzero_si128();
@@ -10819,11 +11563,11 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
         vector float vyd = vec_splats(y[i].d);
         vector float vd = vec_mul(vxd, vyd);
 
-        const uint8_t *  restrict q3 = x[i].qs;
-        const uint8_t *  restrict qh = x[i].qh;
-        const uint16_t * restrict signs = (const uint16_t *)(x[i].signs);
-        const uint8_t *  restrict sc = x[i].scales;
-        const int8_t  *  restrict q8 = y[i].qs;
+        const uint8_t *  GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t *  GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].signs);
+        const uint8_t *  GGML_RESTRICT sc = x[i].scales;
+        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;
 
         vector signed int vsumi0 = v0;
         vector signed int vsumi1 = v0;
@@ -10930,10 +11674,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
     __m256 accumf = (__m256)__lasx_xvldi(0);
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict qs = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
         __m256i sumi1 = __lasx_xvldi(0);
         __m256i sumi2 = __lasx_xvldi(0);
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
@@ -10991,10 +11735,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
     float sumf = 0.f;
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * restrict qs = x[i].qs;
-        const uint8_t * restrict qh = x[i].qh;
-        const uint8_t * restrict signs = x[i].signs;
-        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint8_t * GGML_RESTRICT signs = x[i].signs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
         int32_t bsum = 0;
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
             const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
@@ -11046,7 +11790,7 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
 }
 #endif
 
-void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -11054,8 +11798,8 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void
     UNUSED(by);
     UNUSED(bs);
 
-    const block_iq1_s * restrict x = vx;
-    const block_q8_K  * restrict y = vy;
+    const block_iq1_s * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -11117,10 +11861,19 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void
         __m256i sumi = _mm256_setzero_si256();
         int sumi1 = 0;
         for (int ib = 0; ib < QK_K/32; ib += 2) {
+#ifdef __BMI2__
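+            // PDEP scatters the four qs bytes into the low byte of each 16-bit lane and the 3-bit qh fields into bits 8..10, yielding the grid indices directly.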
+            const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib], 0x700070007000700ULL);
+            const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib + 1], 0x700070007000700ULL);
+            const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
+            const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
+            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
+            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
+#else
             const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
                                                     iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
             const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
                                                     iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+#endif
             qs += 8;
             const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
             const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
@@ -11213,10 +11966,10 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void
         vector signed int vsumi3 = vec_splats((int32_t)0);
         vector signed int vsumi8 = vec_splats((int32_t)0);
 
-        const uint8_t  * restrict q1 = x[i].qs;
-        const uint16_t * restrict qh = x[i].qh;
-        const int8_t   * restrict q8 = y[i].qs;
-        const int16_t  * restrict qs = y[i].bsums;
+        const uint8_t  * GGML_RESTRICT q1 = x[i].qs;
+        const uint16_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
+        const int16_t  * GGML_RESTRICT qs = y[i].bsums;
 
         for (int j = 0; j < QK_K/32; j += 2) {
             __builtin_prefetch(q1, 0, 1);
@@ -11377,7 +12130,7 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void
 #endif
 }
 
-void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_iq1_m_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc);
@@ -11385,8 +12138,8 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
     UNUSED(by);
     UNUSED(bs);
 
-    const block_iq1_m * restrict x = vx;
-    const block_q8_K  * restrict y = vy;
+    const block_iq1_m * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -11466,6 +12219,10 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
 
     const __m256i mask = _mm256_set1_epi16(0x7);
     const __m256i mone = _mm256_set1_epi16(1);
+    const __m256i mone8 = _mm256_set1_epi8(1);
+    const __m256i mtwo8 = _mm256_set1_epi8(2);
+    // VPSHUFB cannot cross 128-bit lanes, so the odd shifts go to the upper half.
+    const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0);
 
     __m256 accum1 = _mm256_setzero_ps();
     __m256 accum2 = _mm256_setzero_ps();
@@ -11477,10 +12234,33 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
         const uint16_t * sc = (const uint16_t *)x[i].scales;
 
         scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+        // Extract 3-bit scales (16 values)
+        __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc);
+        scales = _mm256_srlv_epi64(scales, scales_shift);
+        scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone);
+
+        // Indices to repeat each scale 8 times.
+        __m256i scales_idx1 = _mm256_set1_epi16(0x0100);
+        __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8));
 
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
         for (int ib = 0; ib < QK_K/32; ib += 2) {
+#ifdef __BMI2__
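+            // Same PDEP trick as iq1_s: each qh nibble supplies 3 index bits (its top bit is the delta sign, applied below).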
+            const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL)
+                                       | _pdep_u64(*(const uint16_t*)(qh) & 0x7777, 0xf000f000f000f00ULL);
+            const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL)
+                                       | _pdep_u64(*(const uint16_t*)(qh + 2) & 0x7777, 0xf000f000f000f00ULL);
+            const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
+            const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
+            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
+            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
+
+            // Convert signs to bytes 0x81 (negative) or 0x01 (positive)
+            const uint64_t delta_sign = _pdep_u64(*(const uint32_t*)(qh) & 0x88888888, 0xf0f0f0f0f0f0f0f0ULL);
+            const __m256i delta1 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign)));
+            const __m256i delta2 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign >> 32)));
+#else
             const __m256i q1b_1 = _mm256_set_epi64x(
                     iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
                     iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
@@ -11489,11 +12269,6 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
                     iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
                     iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
             );
-            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-
-            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
-            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
 
             const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                      qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
@@ -11503,15 +12278,21 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
                                                      qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
                                                      qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                      qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+#endif
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
 
-            const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
-            const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
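+            // Delta term: flip the sign of q8 per group and horizontally add byte pairs (maddubs against a vector of ones).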
+            const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1));
+            const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2));
 
-            __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
-            __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
+            __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1);
+            __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2);
+
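+            // Advance the shuffle indices to the next 16-bit scale for the following iteration.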
+            scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8);
+            scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8);
 
-            scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
-            scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
             const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
             const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
             const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
@@ -11667,7 +12448,7 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
 #endif
 }
 
-void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
     UNUSED(bx);
@@ -11676,8 +12457,8 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
     assert(n % QK4_NL == 0);
     static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
 
-    const block_iq4_nl * restrict x = vx;
-    const block_q8_0   * restrict y = vy;
+    const block_iq4_nl * GGML_RESTRICT x = vx;
+    const block_q8_0   * GGML_RESTRICT y = vy;
 
     const int nb = n / QK4_NL;
 
@@ -11852,8 +12633,8 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
     const uint8x16_t v_m = vec_splat_u8(0x0F);
 
     for (; ib < nb; ++ib) {
-        const block_iq4_nl * restrict x0 = &x[ib];
-        const block_q8_0   * restrict y0 = &y[ib];
+        const block_iq4_nl * GGML_RESTRICT x0 = &x[ib];
+        const block_q8_0   * GGML_RESTRICT y0 = &y[ib];
 
         const uint8x16_t v_x = vec_xl(0, x0->qs);
         int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
@@ -11881,7 +12662,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
     *s = sumf;
 }
 
-void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
     UNUSED(bx);
@@ -11889,8 +12670,8 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     UNUSED(bs);
     assert(n % QK_K == 0);
 
-    const block_iq4_xs * restrict x = vx;
-    const block_q8_K   * restrict y = vy;
+    const block_iq4_xs * GGML_RESTRICT x = vx;
+    const block_q8_K   * GGML_RESTRICT y = vy;
 
     const int nb = n / QK_K;
 
@@ -12047,9 +12828,9 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 
         uint16_t h = x[ibl].scales_h;
 
-        const uint8_t * restrict q4 = x[ibl].qs;
-        const uint8_t * restrict sc = x[ibl].scales_l;
-        const int8_t  * restrict q8 = y[ibl].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[ibl].qs;
+        const uint8_t * GGML_RESTRICT sc = x[ibl].scales_l;
+        const int8_t  * GGML_RESTRICT q8 = y[ibl].qs;
 
         for (int ib = 0; ib < QK_K/64; ib ++ ) {
             __builtin_prefetch(q4, 0, 1);
@@ -12153,8 +12934,8 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     float sumf = 0;
 
     for (int ibl = 0; ibl < nb; ++ibl) {
-        const uint8_t * restrict q4 = x[ibl].qs;
-        const int8_t  * restrict q8 = y[ibl].qs;
+        const uint8_t * GGML_RESTRICT q4 = x[ibl].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[ibl].qs;
 
         uint16_t h = x[ibl].scales_h;
 
@@ -12234,12 +13015,12 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 
 // ============================ 4-bit non-linear quants
 
-void quantize_row_iq4_nl(const float * restrict x, void * restrict y, int64_t k) {
+void quantize_row_iq4_nl(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
     assert(k % QK4_NL == 0);
     quantize_row_iq4_nl_ref(x, y, k);
 }
 
-void quantize_row_iq4_xs(const float * restrict x, void * restrict y, int64_t k) {
+void quantize_row_iq4_xs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_iq4_xs(x, y, 1, k, NULL);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
index ec60e8fcf..6d4abe4c7 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
@@ -9,6 +9,10 @@
 #include "ggml-impl.h"
 #include "ggml-cpu-quants.h"
 #include "ggml-threading.h"
+#include "unary-ops.h"
+#include "binary-ops.h"
+#include "vec.h"
+#include "ops.h"
 #include "ggml.h"
 
 #include "ollama-debug.h"
@@ -83,30 +87,6 @@
 #define UNUSED GGML_UNUSED
 #define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0)
 
-#if defined(GGML_USE_ACCELERATE)
-#include 
-#endif
-
-// floating point type used to accumulate sums
-typedef double ggml_float;
-
-#define GGML_GELU_FP16
-#define GGML_GELU_QUICK_FP16
-
-#define GGML_SOFT_MAX_UNROLL 4
-#define GGML_VEC_DOT_UNROLL  2
-#define GGML_VEC_MAD_UNROLL  32
-
-//
-// global data
-//
-
-// precomputed gelu table for f16 (128 KB)
-static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
-
-// precomputed quick gelu table for f16 (128 KB)
-static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
-
 #if defined(__ARM_ARCH)
 struct ggml_arm_arch_features_type {
     int has_neon;
@@ -230,29 +210,6 @@ typedef pthread_t ggml_thread_t;
 #include 
 #endif
 
-//
-// cache line
-//
-
-#if defined(__cpp_lib_hardware_interference_size)
-#define CACHE_LINE_SIZE hardware_destructive_interference_size
-#else
-#if defined(__POWER9_VECTOR__)
-#define CACHE_LINE_SIZE 128
-#elif defined(__VXE__) || defined(__VXE2__)
-#define CACHE_LINE_SIZE 256
-#else
-#define CACHE_LINE_SIZE 64
-#endif
-#endif
-
-static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
-
-
-static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
-static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
-static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
-
 static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = {
         .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
@@ -424,887 +381,6 @@ const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type
     return &type_traits_cpu[type];
 }
 
-//
-// simd mappings
-//
-
-// we define a common set of C macros which map to specific intrinsics based on the current architecture
-// we then implement the fundamental computation operations below using only these macros
-// adding support for new architectures requires to define the corresponding SIMD macros
-//
-// GGML_F32_STEP / GGML_F16_STEP
-//   number of elements to process in a single step
-//
-// GGML_F32_EPR / GGML_F16_EPR
-//   number of elements to fit in a single register
-//
-
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
-
-#define GGML_SIMD
-
-// F32 NEON
-
-#define GGML_F32_STEP 16
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4              float32x4_t
-#define GGML_F32x4_ZERO         vdupq_n_f32(0.0f)
-#define GGML_F32x4_SET1(x)      vdupq_n_f32(x)
-#define GGML_F32x4_LOAD         vld1q_f32
-#define GGML_F32x4_STORE        vst1q_f32
-#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
-#define GGML_F32x4_ADD          vaddq_f32
-#define GGML_F32x4_MUL          vmulq_f32
-#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
-#define GGML_F32x4_REDUCE(res, x)                       \
-{                                                       \
-    int offset = GGML_F32_ARR >> 1;                     \
-    for (int i = 0; i < offset; ++i) {                  \
-        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \
-    }                                                   \
-    offset >>= 1;                                       \
-    for (int i = 0; i < offset; ++i) {                  \
-        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \
-    }                                                   \
-    offset >>= 1;                                       \
-    for (int i = 0; i < offset; ++i) {                  \
-        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \
-    }                                                   \
-    (res) = (ggml_float) GGML_F32x4_REDUCE_ONE((x)[0]); \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 NEON
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
-    #define GGML_F16_STEP 32
-    #define GGML_F16_EPR  8
-
-    #define GGML_F16x8              float16x8_t
-    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)
-    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)
-    #define GGML_F16x8_LOAD(x)      vld1q_f16((const ggml_fp16_internal_t *)(x))
-    #define GGML_F16x8_STORE        vst1q_f16
-    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
-    #define GGML_F16x8_ADD          vaddq_f16
-    #define GGML_F16x8_MUL          vmulq_f16
-    #define GGML_F16x8_REDUCE(res, x)                               \
-    do {                                                            \
-        int offset = GGML_F16_ARR >> 1;                             \
-        for (int i = 0; i < offset; ++i) {                          \
-            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
-        }                                                           \
-        offset >>= 1;                                               \
-        for (int i = 0; i < offset; ++i) {                          \
-            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
-        }                                                           \
-        offset >>= 1;                                               \
-        for (int i = 0; i < offset; ++i) {                          \
-            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
-        }                                                           \
-        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
-        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
-        (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
-    } while (0)
-
-    #define GGML_F16_VEC                GGML_F16x8
-    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
-    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
-    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
-    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i])
-    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
-    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
-    #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
-    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE
-#else
-    // if FP16 vector arithmetic is not supported, we use FP32 instead
-    // and take advantage of the vcvt_ functions to convert to/from FP16
-
-    #define GGML_F16_STEP 16
-    #define GGML_F16_EPR  4
-
-    #define GGML_F32Cx4              float32x4_t
-    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)
-    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)
-    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x)))
-    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))
-    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
-    #define GGML_F32Cx4_ADD          vaddq_f32
-    #define GGML_F32Cx4_MUL          vmulq_f32
-    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE
-
-    #define GGML_F16_VEC                GGML_F32Cx4
-    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
-    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
-    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
-    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
-    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
-    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
-    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
-    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
-#endif
-
-#elif defined(__AVX512F__)
-
-#define GGML_SIMD
-
-// F32 AVX512
-
-#define GGML_F32_STEP 64
-#define GGML_F32_EPR  16
-
-#define GGML_F32x16         __m512
-#define GGML_F32x16_ZERO    _mm512_setzero_ps()
-#define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
-#define GGML_F32x16_LOAD    _mm512_loadu_ps
-#define GGML_F32x16_STORE   _mm512_storeu_ps
-// _mm512_fmadd_ps is defined in AVX512F so no guard is required
-#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
-#define GGML_F32x16_ADD     _mm512_add_ps
-#define GGML_F32x16_MUL     _mm512_mul_ps
-#define GGML_F32x16_REDUCE(res, x)                                    \
-do {                                                                  \
-    int offset = GGML_F32_ARR >> 1;                                   \
-    for (int i = 0; i < offset; ++i) {                                \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
-    }                                                                 \
-    offset >>= 1;                                                     \
-    for (int i = 0; i < offset; ++i) {                                \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
-    }                                                                 \
-    offset >>= 1;                                                     \
-    for (int i = 0; i < offset; ++i) {                                \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
-    }                                                                 \
-    res = (ggml_float) _mm512_reduce_add_ps(x[0]);                    \
-} while (0)
-
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC        GGML_F32x16
-#define GGML_F32_VEC_ZERO   GGML_F32x16_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x16_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x16_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x16_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x16_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x16_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x16_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
-
-// F16 AVX512
-
-// F16 AVX
-
-#define GGML_F16_STEP 64
-#define GGML_F16_EPR  16
-
-// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead
-
-#define GGML_F32Cx16             __m512
-#define GGML_F32Cx16_ZERO        _mm512_setzero_ps()
-#define GGML_F32Cx16_SET1(x)     _mm512_set1_ps(x)
-
-// unlike  _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
-// so F16C guard isn't required
-#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
-
-#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
-#define GGML_F32Cx16_ADD         _mm512_add_ps
-#define GGML_F32Cx16_MUL         _mm512_mul_ps
-#define GGML_F32Cx16_REDUCE(res, x)                               \
-do {                                                              \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    res = (ggml_float) _mm512_reduce_add_ps(x[0]);                \
-} while (0)
-
-#define GGML_F16_VEC                GGML_F32Cx16
-#define GGML_F16_VEC_ZERO           GGML_F32Cx16_ZERO
-#define GGML_F16_VEC_SET1           GGML_F32Cx16_SET1
-#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx16_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F32Cx16_FMA
-#define GGML_F16_VEC_ADD            GGML_F32Cx16_ADD
-#define GGML_F16_VEC_MUL            GGML_F32Cx16_MUL
-
-#define GGML_F16_VEC_REDUCE         GGML_F32Cx16_REDUCE
-#elif defined(__AVX__)
-
-#define GGML_SIMD
-
-// F32 AVX
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  8
-
-#define GGML_F32x8         __m256
-#define GGML_F32x8_ZERO    _mm256_setzero_ps()
-#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
-#define GGML_F32x8_LOAD    _mm256_loadu_ps
-#define GGML_F32x8_STORE   _mm256_storeu_ps
-#if defined(__FMA__)
-    #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
-#else
-    #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
-#endif
-#define GGML_F32x8_ADD     _mm256_add_ps
-#define GGML_F32x8_MUL     _mm256_mul_ps
-#define GGML_F32x8_REDUCE(res, x)                                 \
-do {                                                              \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \
-                                 _mm256_extractf128_ps(x[0], 1)); \
-    const __m128 t1 = _mm_hadd_ps(t0, t0);                        \
-    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1));        \
-} while (0)
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC        GGML_F32x8
-#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
-
-// F16 AVX
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR  8
-
-// F16 arithmetic is not supported by AVX, so we use F32 instead
-
-#define GGML_F32Cx8             __m256
-#define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
-#define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
-
-#if defined(__F16C__)
-// the  _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
-#else
-static inline __m256 __avx_f32cx8_load(const ggml_fp16_t * x) {
-    float tmp[8];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return _mm256_loadu_ps(tmp);
-}
-static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
-    float arr[8];
-
-    _mm256_storeu_ps(arr, y);
-
-    for (int i = 0; i < 8; i++)
-        x[i] = GGML_FP32_TO_FP16(arr[i]);
-}
-#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
-#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
-#endif
-
-#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
-#define GGML_F32Cx8_ADD         _mm256_add_ps
-#define GGML_F32Cx8_MUL         _mm256_mul_ps
-#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
-
-#define GGML_F16_VEC                GGML_F32Cx8
-#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
-#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
-#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
-#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
-#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
-#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
-
-#elif defined(__POWER9_VECTOR__)
-
-#define GGML_SIMD
-
-// F32 POWER9
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4              vector float
-#define GGML_F32x4_ZERO         0.0f
-#define GGML_F32x4_SET1         vec_splats
-#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
-#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
-#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
-#define GGML_F32x4_ADD          vec_add
-#define GGML_F32x4_MUL          vec_mul
-#define GGML_F32x4_REDUCE(res, x)              \
-{                                              \
-    int offset = GGML_F32_ARR >> 1;            \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vec_add(x[i], x[offset+i]);     \
-    }                                          \
-    offset >>= 1;                              \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vec_add(x[i], x[offset+i]);     \
-    }                                          \
-    offset >>= 1;                              \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vec_add(x[i], x[offset+i]);     \
-    }                                          \
-    res = vec_extract(x[0], 0) +               \
-          vec_extract(x[0], 1) +               \
-          vec_extract(x[0], 2) +               \
-          vec_extract(x[0], 3);                \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 POWER9
-#define GGML_F16_STEP       GGML_F32_STEP
-#define GGML_F16_EPR        GGML_F32_EPR
-#define GGML_F16_VEC        GGML_F32x4
-#define GGML_F16_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F16_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F16_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F16_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F16_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
-// Use vec_xl, not vec_ld, in case the load address is not aligned.
-#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \
-  vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
-  vec_extract_fp32_from_shortl(vec_xl(0, p))
-#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
-#define GGML_F16_VEC_STORE(p, r, i)                             \
-  if (i & 0x1)                                                  \
-    vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \
-                                   r[i - GGML_ENDIAN_BYTE(0)]), \
-            0, p - GGML_F16_EPR)
-
-#elif defined(__wasm_simd128__)
-
-#define GGML_SIMD
-
-// F32 WASM
-
-#define GGML_F32_STEP 16
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4              v128_t
-#define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f)
-#define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x)
-#define GGML_F32x4_LOAD         wasm_v128_load
-#define GGML_F32x4_STORE        wasm_v128_store
-#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
-#define GGML_F32x4_ADD          wasm_f32x4_add
-#define GGML_F32x4_MUL          wasm_f32x4_mul
-#define GGML_F32x4_REDUCE(res, x)                  \
-{                                                  \
-    int offset = GGML_F32_ARR >> 1;                \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    res = wasm_f32x4_extract_lane(x[0], 0) +       \
-          wasm_f32x4_extract_lane(x[0], 1) +       \
-          wasm_f32x4_extract_lane(x[0], 2) +       \
-          wasm_f32x4_extract_lane(x[0], 3);        \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 WASM
-
-#define GGML_F16_STEP 16
-#define GGML_F16_EPR  4
-
-inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
-    float tmp[4];
-
-    tmp[0] = GGML_FP16_TO_FP32(p[0]);
-    tmp[1] = GGML_FP16_TO_FP32(p[1]);
-    tmp[2] = GGML_FP16_TO_FP32(p[2]);
-    tmp[3] = GGML_FP16_TO_FP32(p[3]);
-
-    return wasm_v128_load(tmp);
-}
-
-inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
-    float tmp[4];
-
-    wasm_v128_store(tmp, x);
-
-    p[0] = GGML_FP32_TO_FP16(tmp[0]);
-    p[1] = GGML_FP32_TO_FP16(tmp[1]);
-    p[2] = GGML_FP32_TO_FP16(tmp[2]);
-    p[3] = GGML_FP32_TO_FP16(tmp[3]);
-}
-
-#define GGML_F16x4             v128_t
-#define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f)
-#define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x)
-#define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x)
-#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
-#define GGML_F16x4_FMA         GGML_F32x4_FMA
-#define GGML_F16x4_ADD         wasm_f32x4_add
-#define GGML_F16x4_MUL         wasm_f32x4_mul
-#define GGML_F16x4_REDUCE(res, x)                  \
-{                                                  \
-    int offset = GGML_F16_ARR >> 1;                \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    res = wasm_f32x4_extract_lane(x[0], 0) +       \
-          wasm_f32x4_extract_lane(x[0], 1) +       \
-          wasm_f32x4_extract_lane(x[0], 2) +       \
-          wasm_f32x4_extract_lane(x[0], 3);        \
-}
-
-#define GGML_F16_VEC                GGML_F16x4
-#define GGML_F16_VEC_ZERO           GGML_F16x4_ZERO
-#define GGML_F16_VEC_SET1           GGML_F16x4_SET1
-#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F16x4_FMA
-#define GGML_F16_VEC_ADD            GGML_F16x4_ADD
-#define GGML_F16_VEC_MUL            GGML_F16x4_MUL
-#define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE
-
-#elif defined(__SSE3__)
-
-#define GGML_SIMD
-
-// F32 SSE
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4         __m128
-#define GGML_F32x4_ZERO    _mm_setzero_ps()
-#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
-#define GGML_F32x4_LOAD    _mm_loadu_ps
-#define GGML_F32x4_STORE   _mm_storeu_ps
-#if defined(__FMA__)
-    // TODO: Does this work?
-    #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
-#else
-    #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
-#endif
-#define GGML_F32x4_ADD     _mm_add_ps
-#define GGML_F32x4_MUL     _mm_mul_ps
-#define GGML_F32x4_REDUCE(res, x)                                 \
-{                                                                 \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
-    }                                                             \
-    const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \
-    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0));        \
-}
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 SSE
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR  4
-
-static inline __m128 __sse_f16x4_load(const ggml_fp16_t * x) {
-    float tmp[4];
-
-    tmp[0] = GGML_FP16_TO_FP32(x[0]);
-    tmp[1] = GGML_FP16_TO_FP32(x[1]);
-    tmp[2] = GGML_FP16_TO_FP32(x[2]);
-    tmp[3] = GGML_FP16_TO_FP32(x[3]);
-
-    return _mm_loadu_ps(tmp);
-}
-
-static inline void __sse_f16x4_store(ggml_fp16_t * x, __m128 y) {
-    float arr[4];
-
-    _mm_storeu_ps(arr, y);
-
-    x[0] = GGML_FP32_TO_FP16(arr[0]);
-    x[1] = GGML_FP32_TO_FP16(arr[1]);
-    x[2] = GGML_FP32_TO_FP16(arr[2]);
-    x[3] = GGML_FP32_TO_FP16(arr[3]);
-}
-
-#define GGML_F32Cx4             __m128
-#define GGML_F32Cx4_ZERO        _mm_setzero_ps()
-#define GGML_F32Cx4_SET1(x)     _mm_set1_ps(x)
-#define GGML_F32Cx4_LOAD(x)     __sse_f16x4_load(x)
-#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
-#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
-#define GGML_F32Cx4_ADD         _mm_add_ps
-#define GGML_F32Cx4_MUL         _mm_mul_ps
-#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
-
-#define GGML_F16_VEC                 GGML_F32Cx4
-#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
-#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
-#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
-#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
-#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
-#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
-
-#elif defined(__loongarch_asx)
-
-#define GGML_SIMD
-
-// F32 LASX
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  8
-
-#define GGML_F32x8         __m256
-#define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0)
-#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
-#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
-#define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0)
-#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
-#define GGML_F32x8_ADD     __lasx_xvfadd_s
-#define GGML_F32x8_MUL     __lasx_xvfmul_s
-#define GGML_F32x8_REDUCE(res, x)                                 \
-do {                                                              \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
-    }                                                             \
-    float *tmp_p = (float *)&x[0]; \
-    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7];  \
-} while (0)
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC        GGML_F32x8
-#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
-
-// F16 LASX
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR  8
-
-// F16 arithmetic is not supported by LASX, so we use F32 instead
-
-#define GGML_F32Cx8          __m256
-#define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
-#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
-
-static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
-    __m256i a;
-    memcpy(&a, x, sizeof(ggml_fp16_t) * 8);
-    a = __lasx_xvpermi_d(a, 0 | (1 << 4));
-    return __lasx_xvfcvtl_s_h(a);
-}
-
-static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
-    __m256i a = __lasx_xvfcvt_h_s(y, y);
-    a = __lasx_xvpermi_d(a, 0 | (2 << 2));
-    memcpy(x, &a, sizeof(ggml_fp16_t) * 8);
-}
-#define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
-#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
-
-#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
-#define GGML_F32Cx8_ADD         __lasx_xvfadd_s
-#define GGML_F32Cx8_MUL         __lasx_xvfmul_s
-#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
-
-#define GGML_F16_VEC                GGML_F32Cx8
-#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
-#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
-#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
-#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
-#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
-#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
-
-#elif defined(__loongarch_sx)
-
-#define GGML_SIMD
-
-// F32 LSX
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4         __m128
-#define GGML_F32x4_ZERO    __lsx_vldi(0)
-#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
-#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
-#define GGML_F32x4_STORE((x),(y))   __lsx_vst((y), (x), 0)
-#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
-#define GGML_F32x4_ADD     __lsx_vfadd_s
-#define GGML_F32x4_MUL     __lsx_vfmul_s
-#define GGML_F32x4_REDUCE(res, x)                                                     \
-{                                                                                     \
-    int offset = GGML_F32_ARR >> 1;                                                   \
-    for (int i = 0; i < offset; ++i) {                                                \
-        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
-    }                                                                                 \
-    offset >>= 1;                                                                     \
-    for (int i = 0; i < offset; ++i) {                                                \
-        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
-    }                                                                                 \
-    offset >>= 1;                                                                     \
-    for (int i = 0; i < offset; ++i) {                                                \
-        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
-    }                                                                                 \
-    __m128i tmp     = __lsx_vsrli_d((__m128i) x[0], 32);                              \
-    tmp             = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]);                    \
-    tmp             = __lsx_vpickev_w(__lsx_vldi(0), tmp);                            \
-    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88);                                     \
-    tmp             = __lsx_vsrli_d((__m128i) t0, 32);                                \
-    tmp             = (__m128i) __lsx_vfadd_s((__m128) tmp, t0);                      \
-    tmp             = __lsx_vpickev_w(__lsx_vldi(0), tmp);                            \
-    res             = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 LSX
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR  4
-
-static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
-    float tmp[4];
-
-    tmp[0] = GGML_FP16_TO_FP32(x[0]);
-    tmp[1] = GGML_FP16_TO_FP32(x[1]);
-    tmp[2] = GGML_FP16_TO_FP32(x[2]);
-    tmp[3] = GGML_FP16_TO_FP32(x[3]);
-
-    return __lsx_vld(tmp, 0);
-}
-
-static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
-    float arr[4];
-
-    __lsx_vst(y, arr, 0);
-
-    x[0] = GGML_FP32_TO_FP16(arr[0]);
-    x[1] = GGML_FP32_TO_FP16(arr[1]);
-    x[2] = GGML_FP32_TO_FP16(arr[2]);
-    x[3] = GGML_FP32_TO_FP16(arr[3]);
-}
-
-#define GGML_F32Cx4             __m128
-#define GGML_F32Cx4_ZERO        __lsx_vldi(0)
-#define GGML_F32Cx4_SET1(x)     __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
-#define GGML_F32Cx4_LOAD(x)     __lsx_f16x4_load(x)
-#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
-#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
-#define GGML_F32Cx4_ADD         __lsx_vfadd_s
-#define GGML_F32Cx4_MUL         __lsx_vfmul_s
-#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
-
-#define GGML_F16_VEC                 GGML_F32Cx4
-#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
-#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
-#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
-#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
-#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
-#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
-
-#elif defined(__VXE__) || defined(__VXE2__)
-
-#define GGML_SIMD
-
-// F32 s390x
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4              __vector float
-#define GGML_F32x4_ZERO         vec_splats(0.0f)
-#define GGML_F32x4_SET1         vec_splats
-#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
-#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
-#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
-#define GGML_F32x4_ADD          vec_add
-#define GGML_F32x4_MUL          vec_mul
-#define GGML_F32x4_REDUCE(res, x)                   \
-{                                                   \
-    int offset = GGML_F32_ARR >> 1;                 \
-    for (int i = 0; i < offset; ++i) {              \
-        x[i] = vec_add(x[i], x[offset + i]);        \
-    }                                               \
-    offset >>= 1;                                   \
-    for (int i = 0; i < offset; ++i) {              \
-        x[i] = vec_add(x[i], x[offset + i]);        \
-    }                                               \
-    offset >>= 1;                                   \
-    for (int i = 0; i < offset; ++i) {              \
-        x[i] = vec_add(x[i], x[offset + i]);        \
-    }                                               \
-    res = vec_extract(x[0], 0) +                    \
-          vec_extract(x[0], 1) +                    \
-          vec_extract(x[0], 2) +                    \
-          vec_extract(x[0], 3);                     \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 s390x
-#define GGML_F16_STEP GGML_F32_STEP
-#define GGML_F16_EPR  GGML_F32_EPR
-
-static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) {
-    float tmp[4];
-
-    for (int i = 0; i < 4; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return vec_xl(0, tmp);
-}
-
-static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) {
-    float arr[4];
-
-    vec_xst(y, 0, arr);
-
-    for (int i = 0; i < 4; i++) {
-        x[i] = GGML_FP32_TO_FP16(arr[i]);
-    }
-}
-
-#define GGML_F16_VEC                GGML_F32x4
-#define GGML_F16_VEC_ZERO           GGML_F32x4_ZERO
-#define GGML_F16_VEC_SET1           GGML_F32x4_SET1
-#define GGML_F16_VEC_LOAD(p, i)     __lzs_f16cx4_load(p)
-#define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F32x4_FMA
-#define GGML_F16_VEC_ADD            GGML_F32x4_ADD
-#define GGML_F16_VEC_MUL            GGML_F32x4_MUL
-#define GGML_F16_VEC_REDUCE         GGML_F32x4_REDUCE
-
-#endif
-
-// GGML_F32_ARR / GGML_F16_ARR
-//   number of registers to use per step
-#ifdef GGML_SIMD
-#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
-#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
-#endif
-
 //
 // Threading defs
 //
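
For orientation before the next hunk: the SIMD-mapping layer removed above parameterizes each backend by GGML_F32_STEP (elements handled per unrolled step) and GGML_F32_EPR (elements per vector register), with GGML_F32_ARR = GGML_F32_STEP / GGML_F32_EPR accumulators in flight. The vector primitives deleted in the following hunk (ggml_vec_dot_f32 and friends) consume those constants: the main loop covers np = n & ~(GGML_F32_STEP - 1) elements and the remaining n - np elements are handled as scalar leftovers. Judging by the includes added in the first hunk of this file (vec.h, ops.h, unary-ops.h, binary-ops.h), these definitions now come from dedicated headers rather than being inlined here. A small stand-alone illustration of the bookkeeping, using the NEON F32 values from the removed block purely as an assumed example:

    /* Hypothetical illustration of the STEP/EPR/ARR bookkeeping used by the
     * removed kernels; not part of the patch itself. */
    #include <stdio.h>

    int main(void) {
        const int step = 16;              /* GGML_F32_STEP on NEON (see above)  */
        const int epr  = 4;               /* GGML_F32_EPR  on NEON (see above)  */
        const int arr  = step / epr;      /* accumulator registers per step     */
        const int n    = 103;             /* arbitrary vector length            */
        const int np   = n & ~(step - 1); /* elements covered by the SIMD loop  */
        printf("arr=%d np=%d leftovers=%d\n", arr, np, n - np);
        return 0;
    }
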
@@ -1404,864 +480,6 @@ struct ggml_compute_state {
     int ith;
 };
 
-//
-// fundamental operations
-//
-
-inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t   v) { for (int i = 0; i < n; ++i) x[i] = v;    }
-inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
-
-inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
-inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }
-inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
-inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
-inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
-inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
-inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
-inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
-inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
-inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
-
-static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
-   assert(nrc == 1);
-   UNUSED(nrc);
-   UNUSED(bx);
-   UNUSED(by);
-   UNUSED(bs);
-
-#if defined(GGML_SIMD)
-    float sumf = 0.0f;
-    const int np = (n & ~(GGML_F32_STEP - 1));
-
-    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-
-    GGML_F32_VEC ax[GGML_F32_ARR];
-    GGML_F32_VEC ay[GGML_F32_ARR];
-
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-
-            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
-        }
-    }
-
-    // reduce sum0..sum3 to sum0
-    GGML_F32_VEC_REDUCE(sumf, sum);
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        sumf += x[i]*y[i];
-    }
-#else
-    // scalar
-    ggml_float sumf = 0.0;
-    for (int i = 0; i < n; ++i) {
-        sumf += (ggml_float)(x[i]*y[i]);
-    }
-#endif
-
-    *s = sumf;
-}
-
-static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
-    int i = 0;
-    ggml_float sumf = 0;
-
-#if defined(__AVX512BF16__)
-    __m512 c1 = _mm512_setzero_ps();
-    __m512 c2 = _mm512_setzero_ps();
-    for (; i + 64 <= n; i += 64) {
-        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
-                             m512bh(_mm512_loadu_si512((y + i))));
-        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
-                             m512bh(_mm512_loadu_si512((y + i + 32))));
-    }
-    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
-    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
-
-#elif defined(__AVX512F__)
-#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
-    __m512 c1 = _mm512_setzero_ps();
-    __m512 c2 = _mm512_setzero_ps();
-    for (; i + 32 <= n; i += 32) {
-        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
-        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
-    }
-    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
-    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
-
-#undef LOAD
-#elif defined(__AVX2__) || defined(__AVX__)
-#if defined(__AVX2__)
-#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
-#else
-#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
-#endif
-    __m256 c1 = _mm256_setzero_ps();
-    __m256 c2 = _mm256_setzero_ps();
-    __m256 c3 = _mm256_setzero_ps();
-    __m256 c4 = _mm256_setzero_ps();
-    for (; i + 32 <= n; i += 32) {
-        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
-        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
-        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
-        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
-    }
-    __m128 g;
-    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
-                       _mm256_add_ps(c2, c4));
-    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
-                   _mm256_castps256_ps128(c1));
-    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
-    g = _mm_add_ss(g, _mm_movehdup_ps(g));
-    sumf += (ggml_float)_mm_cvtss_f32(g);
-
-#undef LOAD
-#endif
-
-    for (; i < n; ++i) {
-        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
-                             GGML_BF16_TO_FP32(y[i]));
-    }
-    *s = sumf;
-}
-
-static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
-
-    ggml_float sumf = 0.0;
-
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
-
-    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
-
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
-
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-
-            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
-        }
-    }
-
-    // reduce sum0..sum3 to sum0
-    GGML_F16_VEC_REDUCE(sumf, sum);
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
-    }
-#else
-    for (int i = 0; i < n; ++i) {
-        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
-    }
-#endif
-
-    *s = sumf;
-}
-
-// compute GGML_VEC_DOT_UNROLL dot products at once
-// xs - x row stride in bytes
-inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
-    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
-
-    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
-
-    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
-        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
-    }
-
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
-
-    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
-
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
-
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-
-            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
-                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
-
-                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
-            }
-        }
-    }
-
-    // reduce sum0..sum3 to sum0
-    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
-        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
-        }
-    }
-#else
-    for (int i = 0; i < n; ++i) {
-        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
-        }
-    }
-#endif
-
-    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
-        s[i] = sumf[i];
-    }
-}
-
-inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F32_STEP - 1));
-
-    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
-
-    GGML_F32_VEC ax[GGML_F32_ARR];
-    GGML_F32_VEC ay[GGML_F32_ARR];
-
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
-
-            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
-        }
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        y[i] += x[i]*v;
-    }
-#else
-    // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] += x[i]*v;
-    }
-#endif
-}
-
-inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
-
-    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
-
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
-
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
-
-            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
-        }
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
-    }
-#else
-    // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
-    }
-#endif
-}
-
-// xs and vs are byte strides of x and v
-inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
-
-    const float * restrict x[GGML_VEC_MAD_UNROLL];
-    const float * restrict v[GGML_VEC_MAD_UNROLL];
-
-    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
-        x[i] = (const float *) ((const char *) xv + i*xs);
-        v[i] = (const float *) ((const char *) vv + i*vs);
-    }
-
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F32_STEP - 1));
-
-    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
-
-    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
-        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
-    }
-
-    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
-    GGML_F32_VEC ay[GGML_F32_ARR];
-
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-
-            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
-                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
-                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
-            }
-
-            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
-        }
-    }
-
-    // leftovers
-    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
-        for (int i = np; i < n; ++i) {
-            y[i] += x[k][i]*v[k][0];
-        }
-    }
-#else
-    // scalar
-    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
-        for (int i = 0; i < n; ++i) {
-            y[i] += x[k][i]*v[k][0];
-        }
-    }
-#endif
-}
-
-//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
-inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
-#if defined(GGML_USE_ACCELERATE)
-    vDSP_vsmul(y, 1, &v, y, 1, n);
-#elif defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F32_STEP - 1));
-
-    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
-
-    GGML_F32_VEC ay[GGML_F32_ARR];
-
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
-
-            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
-        }
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        y[i] *= v;
-    }
-#else
-    // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] *= v;
-    }
-#endif
-}
-
-inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
-
-    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
-
-    GGML_F16_VEC ay[GGML_F16_ARR];
-
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
-
-            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
-        }
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
-    }
-#else
-    // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
-    }
-#endif
-}
-
-inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
-inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
-inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
-inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }
-inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }
-inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }
-inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
-inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
-inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
-inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
-inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
-inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
-inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
-inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
-// TODO: optimize performance
-inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
-inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
-inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
-
-static const float GELU_COEF_A     = 0.044715f;
-static const float GELU_QUICK_COEF = -1.702f;
-static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
-
-inline static float ggml_gelu_f32(float x) {
-    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
-    const uint16_t * i16 = (const uint16_t *) x;
-    for (int i = 0; i < n; ++i) {
-        y[i] = ggml_table_gelu_f16[i16[i]];
-    }
-}
-
-#ifdef GGML_GELU_FP16
-inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
-    uint16_t t;
-    for (int i = 0; i < n; ++i) {
-        if (x[i] <= -10.0f) {
-            y[i] = 0.0f;
-        } else if (x[i] >= 10.0f) {
-            y[i] = x[i];
-        } else {
-            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
-            memcpy(&t, &fp16, sizeof(uint16_t));
-            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
-        }
-    }
-}
-#else
-inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
-    for (int i = 0; i < n; ++i) {
-        y[i] = ggml_gelu_f32(x[i]);
-    }
-}
-#endif
-
-inline static float ggml_gelu_quick_f32(float x) {
-    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
-}
-
-//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
-//    const uint16_t * i16 = (const uint16_t *) x;
-//    for (int i = 0; i < n; ++i) {
-//        y[i] = ggml_table_gelu_quick_f16[i16[i]];
-//    }
-//}
-
-#ifdef GGML_GELU_QUICK_FP16
-inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
-    uint16_t t;
-    for (int i = 0; i < n; ++i) {
-        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
-        memcpy(&t, &fp16, sizeof(uint16_t));
-        y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
-    }
-}
-#else
-inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
-    for (int i = 0; i < n; ++i) {
-        y[i] = ggml_gelu_quick_f32(x[i]);
-    }
-}
-#endif
-
-// Sigmoid Linear Unit (SiLU) function
-inline static float ggml_silu_f32(float x) {
-    return x/(1.0f + expf(-x));
-}
-
-#if __FINITE_MATH_ONLY__
-#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
-#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
-#endif
-
-#if defined(__ARM_NEON) && defined(__aarch64__)
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static float32x4_t ggml_v_expf(float32x4_t x) {
-    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
-    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
-    const float32x4_t n = vsubq_f32(z, r);
-    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
-                                    vdupq_n_f32(0x1.7f7d1cp-20f));
-    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
-    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
-    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
-    const float32x4_t u = vmulq_f32(b, b);
-    const float32x4_t j = vfmaq_f32(
-        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
-        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
-                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
-    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
-        return vfmaq_f32(k, j, k);
-    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
-    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
-    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
-    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
-                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static float32x4_t ggml_v_silu(float32x4_t x) {
-    const float32x4_t one = vdupq_n_f32(1.0f);
-    const float32x4_t zero = vdupq_n_f32(0.0f);
-    const float32x4_t neg_x = vsubq_f32(zero, x);
-    const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
-    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
-    return vdivq_f32(x, one_plus_exp_neg_x);
-}
-
-#elif defined(__AVX512F__) && defined(__AVX512DQ__)
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static __m512 ggml_v_expf(__m512 x) {
-  const __m512 r = _mm512_set1_ps(0x1.8p23f);
-  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
-  const __m512 n = _mm512_sub_ps(z, r);
-  const __m512 b =
-      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
-                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
-  const __mmask16 d =
-      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
-  const __m512 u = _mm512_mul_ps(b, b);
-  const __m512 j = _mm512_fmadd_ps(
-      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
-                                      _mm512_set1_ps(0x1.573e2ep-5f)),
-                      u,
-                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
-                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
-      u,
-      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
-  const __m512 res = _mm512_scalef_ps(j, n);
-  if (_mm512_kortestz(d, d))
-    return res;
-  const __m512 zero = _mm512_setzero_ps();
-  const __m512 alt = _mm512_mask_blend_ps(
-      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
-  return _mm512_mask_blend_ps(d, res, alt);
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m512 ggml_v_silu(__m512 x) {
-    const __m512 one = _mm512_set1_ps(1);
-    const __m512 zero = _mm512_setzero_ps();
-    const __m512 neg_x = _mm512_sub_ps(zero, x);
-    const __m512 exp_neg_x = ggml_v_expf(neg_x);
-    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
-    return _mm512_div_ps(x, one_plus_exp_neg_x);
-}
-
-#elif defined(__AVX2__) && defined(__FMA__)
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static __m256 ggml_v_expf(__m256 x) {
-  const __m256 r = _mm256_set1_ps(0x1.8p23f);
-  const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
-  const __m256 n = _mm256_sub_ps(z, r);
-  const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
-                                    _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
-  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
-  const __m256 k = _mm256_castsi256_ps(
-      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
-  const __m256i c = _mm256_castps_si256(
-      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
-                    _mm256_set1_ps(126), _CMP_GT_OQ));
-  const __m256 u = _mm256_mul_ps(b, b);
-  const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
-                                                                   _mm256_set1_ps(0x1.573e2ep-5f)), u,
-                                                   _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
-                                                                   _mm256_set1_ps(0x1.fffdb6p-2f))),
-                                   u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
-  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
-    return _mm256_fmadd_ps(j, k, k);
-  const __m256i g = _mm256_and_si256(
-      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
-      _mm256_set1_epi32(0x82000000u));
-  const __m256 s1 =
-      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
-  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
-  const __m256i d = _mm256_castps_si256(
-      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
-                    _mm256_set1_ps(192), _CMP_GT_OQ));
-  return _mm256_or_ps(
-      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
-      _mm256_andnot_ps(
-          _mm256_castsi256_ps(d),
-          _mm256_or_ps(
-              _mm256_and_ps(_mm256_castsi256_ps(c),
-                            _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
-              _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m256 ggml_v_silu(__m256 x) {
-    const __m256 one = _mm256_set1_ps(1);
-    const __m256 zero = _mm256_setzero_ps();
-    const __m256 neg_x = _mm256_sub_ps(zero, x);
-    const __m256 exp_neg_x = ggml_v_expf(neg_x);
-    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
-    return _mm256_div_ps(x, one_plus_exp_neg_x);
-}
-
-#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
-
-#if defined(__FMA__)
-#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
-#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
-#else
-#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
-#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
-#endif
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static __m128 ggml_v_expf(__m128 x) {
-    const __m128 r = _mm_set1_ps(0x1.8p23f);
-    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
-    const __m128 n = _mm_sub_ps(z, r);
-    const __m128 b =
-        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
-    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
-    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
-    const __m128i c =
-        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
-    const __m128 u = _mm_mul_ps(b, b);
-    const __m128 j =
-        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
-                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
-                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
-    if (!_mm_movemask_epi8(c))
-        return MADD128(j, k, k);
-    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
-                                    _mm_set1_epi32(0x82000000u));
-    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
-    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
-    const __m128i d =
-        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
-    return _mm_or_ps(
-        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
-        _mm_andnot_ps(_mm_castsi128_ps(d),
-                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
-                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m128 ggml_v_silu(__m128 x) {
-    const __m128 one = _mm_set1_ps(1);
-    const __m128 zero = _mm_setzero_ps();
-    const __m128 neg_x = _mm_sub_ps(zero, x);
-    const __m128 exp_neg_x = ggml_v_expf(neg_x);
-    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
-    return _mm_div_ps(x, one_plus_exp_neg_x);
-}
-
-#endif // __ARM_NEON / __AVX2__ / __SSE2__
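
// Editor's sketch (not part of the patch): every ggml_v_expf variant above uses the
// same range reduction — n = round(x*log2(e)) is obtained by adding the shifter
// 0x1.8p23f and subtracting it back out, the remainder b = x - n*ln2 is formed with a
// split hi/lo constant, exp(b) is approximated by a short polynomial, and the result
// is scaled by 2^n. A scalar version of that core, without the overflow/underflow
// fix-ups the vector code performs, looks roughly like this:
static inline float ggml_v_expf_scalar_sketch(float x) {
    const float r = 0x1.8p23f;                                   // round-to-nearest shifter
    const float z = x*0x1.715476p+0f + r;                        // x*log2(e), rounded into z's low bits
    const float n = z - r;                                       // nearest integer to x*log2(e)
    const float b = (x - n*0x1.62e4p-1f) - n*0x1.7f7d1cp-20f;    // x - n*ln2, hi/lo split
    const float u = b*b;
    const float j = ((0x1.0e4020p-7f*b + 0x1.573e2ep-5f)*u
                  +  (0x1.555e66p-3f*b + 0x1.fffdb6p-2f))*u
                  +   0x1.ffffecp-1f*b;                          // polynomial approximation of expm1(b)
    return ldexpf(1.0f + j, (int) n);                            // scale by 2^n
}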
-
-static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
-    int i = 0;
-#if defined(__AVX512F__) && defined(__AVX512DQ__)
-    for (; i + 15 < n; i += 16) {
-        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
-    }
-#elif defined(__AVX2__) && defined(__FMA__)
-    for (; i + 7 < n; i += 8) {
-        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
-    }
-#elif defined(__SSE2__)
-    for (; i + 3 < n; i += 4) {
-        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
-    }
-#elif defined(__ARM_NEON) && defined(__aarch64__)
-    for (; i + 3 < n; i += 4) {
-        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
-    }
-#endif
-    for (; i < n; ++i) {
-        y[i] = ggml_silu_f32(x[i]);
-    }
-}
-
-static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
-    int i = 0;
-    ggml_float sum = 0;
-#if defined(__AVX512F__) && defined(__AVX512DQ__)
-    for (; i + 15 < n; i += 16) {
-        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
-                                               _mm512_set1_ps(max)));
-        _mm512_storeu_ps(y + i, val);
-        sum += (ggml_float)_mm512_reduce_add_ps(val);
-    }
-#elif defined(__AVX2__) && defined(__FMA__)
-    for (; i + 7 < n; i += 8) {
-        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
-                                               _mm256_set1_ps(max)));
-        _mm256_storeu_ps(y + i, val);
-        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
-                                 _mm256_castps256_ps128(val));
-        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
-        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
-        sum += (ggml_float)_mm_cvtss_f32(val2);
-    }
-#elif defined(__SSE2__)
-    for (; i + 3 < n; i += 4) {
-        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
-                                            _mm_set1_ps(max)));
-        _mm_storeu_ps(y + i, val);
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
-        val = _mm_add_ss(val, _mm_movehdup_ps(val));
-#else
-        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
-        val = _mm_add_ps(val, tmp);
-        tmp = _mm_movehl_ps(tmp, val);
-        val = _mm_add_ss(val, tmp);
-#endif
-        sum += (ggml_float)_mm_cvtss_f32(val);
-    }
-#elif defined(__ARM_NEON) && defined(__aarch64__)
-    for (; i + 3 < n; i += 4) {
-        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
-                                                vdupq_n_f32(max)));
-        vst1q_f32(y + i, val);
-        sum += (ggml_float)vaddvq_f32(val);
-    }
-#endif
-    for (; i < n; ++i) {
-        float val = expf(x[i] - max);
-        sum += (ggml_float)val;
-        y[i] = val;
-    }
-    return sum;
-}
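
// Editor's sketch (not part of the patch): the sum returned above is what a caller
// divides by to complete the softmax. A plain scalar wrapper, using only the function
// defined directly above, could look like this:
static void softmax_row_sketch(const int n, float * y, const float * x) {
    float maxv = -INFINITY;
    for (int i = 0; i < n; ++i) {
        maxv = MAX(maxv, x[i]);                                   // subtract the max for numerical stability
    }
    const ggml_float sum = ggml_vec_soft_max_f32(n, y, x, maxv);  // y[i] = exp(x[i] - maxv)
    for (int i = 0; i < n; ++i) {
        y[i] /= (float) sum;                                      // normalize so the row sums to 1
    }
}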
-
-static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
-    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_sum)
-
-    int i = 0;
-    ggml_float sum = 0;
-    for (; i < n; ++i) {
-        float val = x[i] - max;
-        y[i] = val;
-        sum += (ggml_float)expf(val);
-    }
-    return sum = (ggml_float)logf(sum);
-}
-
-inline static float ggml_silu_backward_f32(float x, float dy) {
-    const float s = 1.0f/(1.0f + expf(-x));
-    return dy*s*(1.0f + x*(1.0f - s));
-}
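
// Editor's note: with s(x) = 1/(1 + exp(-x)) and silu(x) = x*s(x), the product rule gives
// d/dx silu(x) = s(x) + x*s(x)*(1 - s(x)) = s(x)*(1 + x*(1 - s(x))),
// which is exactly the factor applied to dy above.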
-
-inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
-    for (int i = 0; i < n; ++i) {
-        dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
-    }
-}
-
-inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
-#ifndef GGML_USE_ACCELERATE
-    ggml_float sum = 0.0;
-    for (int i = 0; i < n; ++i) {
-        sum += (ggml_float)x[i];
-    }
-    *s = sum;
-#else
-    vDSP_sve(x, 1, s, n);
-#endif
-}
-
-inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
-    ggml_float sum = 0.0;
-    for (int i = 0; i < n; ++i) {
-        sum += (ggml_float)x[i];
-    }
-    *s = sum;
-}
-
-inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
-    float sum = 0.0f;
-    for (int i = 0; i < n; ++i) {
-        sum += GGML_FP16_TO_FP32(x[i]);
-    }
-    *s = sum;
-}
-
-inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
-    float sum = 0.0f;
-    for (int i = 0; i < n; ++i) {
-        sum += GGML_BF16_TO_FP32(x[i]);
-    }
-    *s = sum;
-}
-
-inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
-#ifndef GGML_USE_ACCELERATE
-    float max = -INFINITY;
-    for (int i = 0; i < n; ++i) {
-        max = MAX(max, x[i]);
-    }
-    *s = max;
-#else
-    vDSP_maxv(x, 1, s, n);
-#endif
-}
-
-inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
-    ggml_vec_norm_f32(n, s, x);
-    *s = 1.f/(*s);
-}
-
-inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
-    float max = -INFINITY;
-    int idx = 0;
-    for (int i = 0; i < n; ++i) {
-        max = MAX(max, x[i]);
-        if (max == x[i]) { idx = i; }
-    }
-    *s = idx;
-}
-
 // Helpers for polling loops
 #if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) )
 static inline void ggml_thread_cpu_relax(void) {
@@ -2955,4484 +1173,6 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
 
 ////////////////////////////////////////////////////////////////////////////////
 
-// ggml_compute_forward_dup
-
-static void ggml_compute_forward_dup_same_cont(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-    GGML_ASSERT(src0->type == dst->type);
-
-    const size_t nb0 = ggml_type_size(src0->type);
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by elements
-    const int ne = ggml_nelements(dst);
-    const int dr = (ne + nth - 1) / nth;
-    const int ie0 = dr * ith;
-    const int ie1 = MIN(ie0 + dr, ne);
-
-    if (ie0 < ie1) {
-        memcpy(
-            ((char *)  dst->data + ie0*nb0),
-            ((char *) src0->data + ie0*nb0),
-            (ie1 - ie0) * nb0);
-    }
-}
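
// Worked example (editor's note): the split above hands each of nth threads a contiguous
// block of ceil(ne/nth) elements. With ne = 10 and nth = 4, dr = 3 and the per-thread
// ranges are [0,3), [3,6), [6,9) and [9,10); the MIN clips the last range, and any extra
// threads end up with an empty range (ie0 >= ie1) and simply skip the memcpy.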
-
-static void ggml_compute_forward_dup_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
-    if (ggml_is_contiguous(dst)) {
-        if (nb00 == sizeof(ggml_fp16_t)) {
-            if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
-                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
-                            }
-
-                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-        return;
-    }
-
-    // dst counters
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
-
-                        if (++i10 == ne00) {
-                            i10 = 0;
-                            if (++i11 == ne01) {
-                                i11 = 0;
-                                if (++i12 == ne02) {
-                                    i12 = 0;
-                                    if (++i13 == ne03) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
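
// Editor's note: when the shapes of src0 and dst differ, the loops above keep a single
// flattened destination cursor (i10, i11, i12, i13) that advances once per copied element
// and wraps against ne0..ne3, while the pre/post "i10 += ne00 * ..." adjustments skip the
// rows owned by other threads. The bf16, f32 and byte-wise variants below reuse the same
// pattern.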
-
-static void ggml_compute_forward_dup_bf16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
-    if (ggml_is_contiguous(dst)) {
-        if (nb00 == sizeof(ggml_bf16_t)) {
-            if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
-                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
-                            }
-
-                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-        return;
-    }
-
-    // dst counters
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_BF16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
-
-                        if (++i10 == ne00) {
-                            i10 = 0;
-                            if (++i11 == ne01) {
-                                i11 = 0;
-                                if (++i12 == ne02) {
-                                    i12 = 0;
-                                    if (++i13 == ne03) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-static void ggml_compute_forward_dup_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    if (ggml_is_contiguous(dst)) {
-        // TODO: simplify
-        if (nb00 == sizeof(float)) {
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-
-        return;
-    }
-
-    // dst counters
-
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(float));
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_BF16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
-static void ggml_compute_forward_dup_bytes(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-    GGML_ASSERT(src0->type == dst->type);
-
-    GGML_TENSOR_UNARY_OP_LOCALS;
-
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
-        ggml_compute_forward_dup_same_cont(params, dst);
-        return;
-    }
-
-    const size_t type_size = ggml_type_size(src0->type);
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == type_size && nb0 == type_size) {
-        // copy by rows
-        const size_t rs = ne00 * type_size;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    if (ggml_is_contiguous(dst)) {
-        size_t id = 0;
-        char * dst_ptr = (char *) dst->data;
-        const size_t rs = ne00 * type_size;
-
-        if (nb00 == type_size) {
-            // src0 is contiguous on first dimension, copy by rows
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    id += rs * ir0;
-                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                        memcpy(dst_ptr + id, src0_ptr, rs);
-                        id += rs;
-                    }
-                    id += rs * (ne01 - ir1);
-                }
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    id += rs * ir0;
-                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, type_size);
-
-                            id += type_size;
-                        }
-                    }
-                    id += rs * (ne01 - ir1);
-                }
-            }
-        }
-
-        return;
-    }
-
-    // dst counters
-
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
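-    // (i10, i11, i12, i13) walk dst as a mixed-radix counter over (ne0, ne1, ne2, ne3);
-    // the carry loops below keep them in step with the flat element position, so dst
-    // may have a different (but equal-sized) shape than src0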
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            i10 += ne00 * ir0;
-            while (i10 >= ne0) {
-                i10 -= ne0;
-                if (++i11 == ne1) {
-                    i11 = 0;
-                    if (++i12 == ne2) {
-                        i12 = 0;
-                        if (++i13 == ne3) {
-                            i13 = 0;
-                        }
-                    }
-                }
-            }
-            for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                          char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                    memcpy(dst_ptr, src0_ptr, type_size);
-
-                    if (++i10 == ne0) {
-                        i10 = 0;
-                        if (++i11 == ne1) {
-                            i11 = 0;
-                            if (++i12 == ne2) {
-                                i12 = 0;
-                                if (++i13 == ne3) {
-                                    i13 = 0;
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-            i10 += ne00 * (ne01 - ir1);
-            while (i10 >= ne0) {
-                i10 -= ne0;
-                if (++i11 == ne1) {
-                    i11 = 0;
-                    if (++i12 == ne2) {
-                        i12 = 0;
-                        if (++i13 == ne3) {
-                            i13 = 0;
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_dup_q(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const enum ggml_type type = src0->type;
-    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-
-    size_t qk = ggml_blck_size(type);
-    const int64_t nr = ggml_nelements(src1) / qk;
-
-    // destination must be contiguous in the first dimension
-    GGML_ASSERT(nb10 == ggml_type_size(dst->type));
-    // must either have first dimension large enough to hold a row, or fully contiguous
-    GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t ir = ir0; ir < ir1; ++ir) {
-
-        uint32_t i = ir * qk;
-
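-        // i is the flat element index of this block of qk values: it is decomposed
-        // into src0 coordinates (i00..i03) to find the quantized block and into dst
-        // coordinates (i10..i13) to find where the dequantized floats are written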
-        const int64_t i03 = i/(ne00 * ne01 * ne02);
-        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-        const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-        const int64_t i13 = i/(ne10 * ne11 * ne12);
-        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
-
-        dequantize_row_q(
-                (const void *) ((char *) src0->data + x_offset),
-                     (float *) ((char *)  dst->data + dst_offset), qk);
-    }
-}
-
-static void ggml_compute_forward_dup(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (src0->type == dst->type) {
-        ggml_compute_forward_dup_bytes(params, dst);
-        return;
-    }
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_dup_f16(params, dst);
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                ggml_compute_forward_dup_bf16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_dup_f32(params, dst);
-            } break;
-        default:
-            {
-                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_dup_q(params, dst);
-                    break;
-                }
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_add
-
-static void ggml_compute_forward_add_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
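-    // nb10 == sizeof(float) means src1 rows are contiguous in the innermost
-    // dimension, so whole (sub-)rows can be added with the vectorized path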
-    if (nb10 == sizeof(float)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-            const int64_t nr0 = ne00 / ne10;
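-            // ggml_can_repeat guarantees ne10 divides ne00, so the src1 row is
-            // broadcast nr0 times across each src0/dst row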
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
-            for (int64_t r = 0; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
-                vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
-#else
-                ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
-            }
-        }
-    } else {
-        // src1 is not contiguous
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                const int64_t i10 = i0 % ne10;
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
-                dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_add_f16_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    if (dst->type == GGML_TYPE_F32) {
-        GGML_ASSERT( nb0 == sizeof(float));
-    }
-    else {
-        GGML_ASSERT(dst->type  == GGML_TYPE_F16);
-        GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
-    }
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(float)) {
-        if (dst->type == GGML_TYPE_F16) {
-            for (int ir = ir0; ir < ir1; ++ir) {
-                // src0, src1 and dst are same shape => same indices
-                const int i3 = ir/(ne2*ne1);
-                const int i2 = (ir - i3*ne2*ne1)/ne1;
-                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-                ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-                for (int i = 0; i < ne0; i++) {
-                    dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
-                }
-            }
-        } else {
-            for (int ir = ir0; ir < ir1; ++ir) {
-                // src0, src1 and dst are same shape => same indices
-                const int i3 = ir/(ne2*ne1);
-                const int i2 = (ir - i3*ne2*ne1)/ne1;
-                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-                for (int i = 0; i < ne0; i++) {
-                    dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
-                }
-            }
-        }
-    }
-    else {
-        // src1 is not contiguous
-        GGML_ABORT("fatal error");
-    }
-}
-
-static void ggml_compute_forward_add_bf16_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    if (dst->type == GGML_TYPE_F32) {
-        GGML_ASSERT( nb0 == sizeof(float));
-    }
-    else {
-        GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
-        GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
-    }
-
-    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(float)) {
-        if (dst->type == GGML_TYPE_BF16) {
-            for (int ir = ir0; ir < ir1; ++ir) {
-                // src0, src1 and dst are same shape => same indices
-                const int i3 = ir/(ne2*ne1);
-                const int i2 = (ir - i3*ne2*ne1)/ne1;
-                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-                ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-                ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-                for (int i = 0; i < ne0; i++) {
-                    dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
-                }
-            }
-        } else {
-            for (int ir = ir0; ir < ir1; ++ir) {
-                // src0, src1 and dst are same shape => same indices
-                const int i3 = ir/(ne2*ne1);
-                const int i2 = (ir - i3*ne2*ne1)/ne1;
-                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-                ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-                for (int i = 0; i < ne0; i++) {
-                    dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
-                }
-            }
-        }
-    }
-    else {
-        // src1 is not contiguous
-        GGML_ABORT("fatal error");
-    }
-}
-
-static void ggml_compute_forward_add_f16_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(ggml_fp16_t)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
-            }
-        }
-    }
-    else {
-        // src1 is not contiguous
-        GGML_ABORT("fatal error");
-    }
-}
-
-static void ggml_compute_forward_add_bf16_bf16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
-    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
-    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(ggml_bf16_t)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-            ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-            ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
-            }
-        }
-    }
-    else {
-        // src1 is not contiguous
-        GGML_ABORT("fatal error");
-    }
-}
-
-static void ggml_compute_forward_add_q_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const enum ggml_type type = src0->type;
-    const enum ggml_type dtype = dst->type;
-    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    GGML_ASSERT(ggml_is_quantized(src0->type));
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
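-    // per-thread scratch row of floats; the CACHE_LINE_SIZE_F32 padding keeps
-    // adjacent threads' buffers on separate cache lines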
-    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-        // src1 and dst are same shape as src0 => same indices
-        const int i13 = i03;
-        const int i12 = i02;
-        const int i11 = i01;
-
-        const int i3 = i03;
-        const int i2 = i02;
-        const int i1 = i01;
-
-        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
-        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
-
-        assert(ne00 % 32 == 0);
-
-        // dequantize row from src0 to temp buffer
-        dequantize_row_q(src0_row, wdata, ne00);
-        // add src1
-        ggml_vec_acc_f32(ne00, wdata, src1_row);
-        // quantize row to dst
-        if (quantize_row_q != NULL) {
-            quantize_row_q(wdata, dst_row, ne00);
-        } else {
-            memcpy(dst_row, wdata, ne0*nb0);
-        }
-    }
-}
-
-static void ggml_compute_forward_add(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                if (src1->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_add_f32(params, dst);
-                }
-                else {
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_TYPE_F16:
-            {
-                if (src1->type == GGML_TYPE_F16) {
-                    ggml_compute_forward_add_f16_f16(params, dst);
-                }
-                else if (src1->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_add_f16_f32(params, dst);
-                }
-                else {
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                if (src1->type == GGML_TYPE_BF16) {
-                    ggml_compute_forward_add_bf16_bf16(params, dst);
-                }
-                else if (src1->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_add_bf16_f32(params, dst);
-                }
-                else {
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-            {
-                ggml_compute_forward_add_q_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_add1
-
-static void ggml_compute_forward_add1_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-#ifdef GGML_USE_ACCELERATE
-        UNUSED(ggml_vec_add1_f32);
-
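-        // the zero stride on src1 makes vDSP re-read the same scalar for every
-        // output element, broadcasting the value being added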
-        vDSP_vadd(
-                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                (float *) ((char *) src1->data), 0,
-                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                ne0);
-#else
-        ggml_vec_add1_f32(ne0,
-                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-               *(float *) src1->data);
-#endif
-    }
-}
-
-static void ggml_compute_forward_add1_f16_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    // scalar to add
-    const float v = *(float *) src1->data;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-        for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
-        }
-    }
-}
-
-static void ggml_compute_forward_add1_f16_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    // scalar to add
-    const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-        for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
-        }
-    }
-}
-
-static void ggml_compute_forward_add1_q_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    // scalar to add
-    const float v = *(float *) src1->data;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const enum ggml_type type = src0->type;
-    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
-
-    // we don't support permuted src0
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    GGML_ASSERT(ggml_is_quantized(src0->type));
-    GGML_ASSERT(dst->type == src0->type);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        void  * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
-        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb3 ));
-
-        assert(ne0 % 32 == 0);
-
-        // dequantize row from src0 to temp buffer
-        dequantize_row_q(src0_row, wdata, ne0);
-        // add src1
-        ggml_vec_acc1_f32(ne0, wdata, v);
-        // quantize row to dst
-        quantize_row_q(wdata, dst_row, ne0);
-    }
-}
-
-static void ggml_compute_forward_add1_bf16_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    // scalar to add
-    const float v = *(float *) src1->data;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-        for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
-        }
-    }
-}
-
-static void ggml_compute_forward_add1_bf16_bf16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    // scalar to add
-    const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
-    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
-    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-        for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
-        }
-    }
-}
-
-static void ggml_compute_forward_add1(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_add1_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-            {
-                if (src1->type == GGML_TYPE_F16) {
-                    ggml_compute_forward_add1_f16_f16(params, dst);
-                }
-                else if (src1->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_add1_f16_f32(params, dst);
-                }
-                else {
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                if (src1->type == GGML_TYPE_BF16) {
-                    ggml_compute_forward_add1_bf16_bf16(params, dst);
-                }
-                else if (src1->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_add1_bf16_f32(params, dst);
-                }
-                else {
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-            {
-                ggml_compute_forward_add1_q_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_acc
-
-static void ggml_compute_forward_acc_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-
-    // view src0 and dst with these strides and data offset in bytes during acc
-    // nb0 is implicitly element_size because src0 and dst are contiguous
-    size_t nb1     = ((int32_t *) dst->op_params)[0];
-    size_t nb2     = ((int32_t *) dst->op_params)[1];
-    size_t nb3     = ((int32_t *) dst->op_params)[2];
-    size_t offset  = ((int32_t *) dst->op_params)[3];
-    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
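-    // op_params packs the view description: the strides nb1/nb2/nb3, the byte
-    // offset into dst, and an inplace flag that skips the initial copy below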
-
-    if (!inplace) {
-        if (params->ith == 0) {
-            // memcpy needs to be synchronized across threads to avoid race conditions:
-            // only thread 0 performs the copy, then all threads meet at the barrier below
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
-        }
-        ggml_barrier(params->threadpool);
-    }
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(src1);
-    const int nc = src1->ne[0];
-
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
-
-    // src0 and dst as viewed during acc
-    const size_t nb0 = ggml_element_size(src0);
-
-    const size_t nb00 = nb0;
-    const size_t nb01 = nb1;
-    const size_t nb02 = nb2;
-    const size_t nb03 = nb3;
-
-    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));
-    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
-
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are viewed with shape of src1 and offset
-        // => same indices
-        const int i3 = ir/(ne12*ne11);
-        const int i2 = (ir - i3*ne12*ne11)/ne11;
-        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
-
-#ifdef GGML_USE_ACCELERATE
-        vDSP_vadd(
-                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
-                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
-#else
-        ggml_vec_add_f32(nc,
-                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
-                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
-                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
-#endif
-    }
-}
-
-static void ggml_compute_forward_acc(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_acc_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sub
-
-static void ggml_compute_forward_sub_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(float)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-            const int64_t nr0 = ne00 / ne10;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
-            for (int64_t r = 0; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
-                vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
-#else
-                ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
-            }
-        }
-    } else {
-        // src1 is not contiguous
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                const int64_t i10 = i0 % ne10;
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
-                dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_sub(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sub_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_mul
-
-static void ggml_compute_forward_mul_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t nr = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    if (nb10 == sizeof(float)) {
-        for (int64_t ir = ith; ir < nr; ir += nth) {
-            // src0 and dst are same shape => same indices
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-            const int64_t nr0 = ne00 / ne10;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
-            for (int64_t r = 0 ; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
-                UNUSED(ggml_vec_mul_f32);
-
-                vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
-#else
-                ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
-            }
-        }
-    } else {
-        // src1 is not contiguous
-        for (int64_t ir = ith; ir < nr; ir += nth) {
-            // src0 and dst are same shape => same indices
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
-            for (int64_t i0 = 0; i0 < ne00; ++i0) {
-                const int64_t i10 = i0 % ne10;
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
-                dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_mul(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_mul_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_div
-
-static void ggml_compute_forward_div_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t nr = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    if (nb10 == sizeof(float)) {
-        for (int64_t ir = ith; ir < nr; ir += nth) {
-            // src0 and dst are same shape => same indices
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-            const int64_t nr0 = ne00 / ne10;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
-            for (int64_t r = 0; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
-                UNUSED(ggml_vec_div_f32);
-
-                vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
-#else
-                ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
-            }
-        }
-    } else {
-        // src1 is not contiguous
-        for (int64_t ir = ith; ir < nr; ir += nth) {
-            // src0 and dst are same shape => same indices
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
-            for (int64_t i0 = 0; i0 < ne00; ++i0) {
-                const int64_t i10 = i0 % ne10;
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
-                dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_div(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_div_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sqr
-
-static void ggml_compute_forward_sqr_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
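-    // simple element-wise ops like this one are not parallelized: only thread 0
-    // does the work and the remaining threads return immediately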
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n     = ggml_nrows(src0);
-    const int nc    = src0->ne[0];
-
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sqr_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_sqr(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sqr_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sqrt
-
-static void ggml_compute_forward_sqrt_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sqrt_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_sqrt(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sqrt_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_log
-
-static void ggml_compute_forward_log_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_log_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_log(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_log_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sin
-
-static void ggml_compute_forward_sin_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sin_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_sin(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sin_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_cos
-
-static void ggml_compute_forward_cos_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_cos_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_cos(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_cos_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sum
-
-static void ggml_compute_forward_sum_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_scalar(dst));
-    assert(src0->nb[0] == sizeof(float));
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
-
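-    // accumulate in ggml_float (a wider type than float) so the running total
-    // loses less precision across many rows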
-    ggml_float sum     = 0;
-    ggml_float row_sum = 0;
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                ggml_vec_sum_f32_ggf(ne00,
-                        &row_sum,
-                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
-                sum += row_sum;
-            }
-        }
-    }
-    ((float *) dst->data)[0] = sum;
-}
-
-static void ggml_compute_forward_sum_f16(
-    const struct ggml_compute_params * params,
-          struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_scalar(dst));
-
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
-
-    float sum = 0;
-    float row_sum = 0;
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                ggml_vec_sum_f16_ggf(ne00,
-                    &row_sum,
-                    (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
-                sum += row_sum;
-            }
-        }
-    }
-    ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
-}
-
-static void ggml_compute_forward_sum_bf16(
-    const struct ggml_compute_params * params,
-          struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_scalar(dst));
-
-    assert(src0->nb[0] == sizeof(ggml_bf16_t));
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
-
-    float sum = 0;
-    float row_sum = 0;
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                ggml_vec_sum_bf16_ggf(ne00,
-                    &row_sum,
-                    (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
-                sum += row_sum;
-            }
-        }
-    }
-    ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
-}
-
-static void ggml_compute_forward_sum(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sum_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_sum_f16(params, dst);
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                ggml_compute_forward_sum_bf16(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
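Note that the F32 sum above accumulates into ggml_float (double on typical builds) while the F16/BF16 paths keep per-row partial sums in float before converting the final scalar back to the storage type. A tiny illustration of why the wider accumulator matters for long reductions (hypothetical values, not ggml code):

#include <stdio.h>

int main(void) {
    // Summing many small values into a float stalls once the running
    // total dwarfs the addends; a double accumulator does not.
    float  sum_f32 = 0.0f;
    double sum_f64 = 0.0;
    for (int i = 0; i < 20000000; i++) {
        sum_f32 += 1.0f;
        sum_f64 += 1.0;
    }
    printf("float: %.1f  double: %.1f\n", sum_f32, sum_f64);
    // float saturates at 16777216 (2^24); double reports 20000000.
    return 0;
}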
-// ggml_compute_forward_sum_rows
-
-static void ggml_compute_forward_sum_rows_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(dst->nb[0] == sizeof(float));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(ne0 == 1);
-    GGML_ASSERT(ne1 == ne01);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
-    for (int64_t i3 = 0; i3 < ne03; i3++) {
-        for (int64_t i2 = 0; i2 < ne02; i2++) {
-            for (int64_t i1 = 0; i1 < ne01; i1++) {
-                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
-                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
-                float row_sum = 0;
-                ggml_vec_sum_f32(ne00, &row_sum, src_row);
-                dst_row[0] = row_sum;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_sum_rows(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sum_rows_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_mean
-
-static void ggml_compute_forward_mean_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(src0->nb[0] == sizeof(float));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    assert(ne0 == 1);
-    assert(ne1 == ne01);
-    assert(ne2 == ne02);
-    assert(ne3 == ne03);
-
-    UNUSED(ne0);
-    UNUSED(ne1);
-    UNUSED(ne2);
-    UNUSED(ne3);
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                ggml_vec_sum_f32(ne00,
-                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
-
-                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_mean(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_mean_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
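mean reuses the row-sum machinery and divides each result by ne00. A compact sketch of that reduce-then-scale step on a plain row-major 2D array (mean_rows is an illustrative name, not part of ggml):

#include <stdio.h>

/* For each row, write sum(row)/nc into out[row] — the same
 * "sum then divide by ne00" step as the removed mean kernel. */
static void mean_rows(int n_rows, int nc, const float *x, float *out) {
    for (int i = 0; i < n_rows; i++) {
        float s = 0.0f;
        for (int k = 0; k < nc; k++) {
            s += x[i*nc + k];
        }
        out[i] = s / (float) nc;
    }
}

int main(void) {
    const float x[2*3] = {1, 2, 3, 10, 20, 30};
    float m[2];
    mean_rows(2, 3, x, m);
    printf("%g %g\n", m[0], m[1]); // 2 20
    return 0;
}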
-// ggml_compute_forward_argmax
-
-static void ggml_compute_forward_argmax_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(src0->nb[0] == sizeof(float));
-    assert(dst->nb[0] == sizeof(float));
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-
-    const size_t nb01 = src0->nb[1];
-    const size_t nb0 = dst->nb[0];
-
-    for (int64_t i1 = 0; i1 < ne01; i1++) {
-        float * src = (float *) ((char *) src0->data + i1*nb01);
-        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);
-        int v = 0;
-        ggml_vec_argmax_f32(ne00, &v, src);
-        dst_[0] = v;
-    }
-}
-
-static void ggml_compute_forward_argmax(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_argmax_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
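argmax reduces each row of floats to the int32 index of its largest element. A minimal reference version of that per-row scan (plain arrays, not the ggml_vec_argmax_f32 routine):

#include <stdint.h>
#include <stdio.h>

/* Index of the largest element in each row; ties resolve to the
 * first occurrence, matching a left-to-right scan. */
static void argmax_rows(int n_rows, int nc, const float *x, int32_t *out) {
    for (int i = 0; i < n_rows; i++) {
        int best = 0;
        for (int k = 1; k < nc; k++) {
            if (x[i*nc + k] > x[i*nc + best]) {
                best = k;
            }
        }
        out[i] = (int32_t) best;
    }
}

int main(void) {
    const float x[2*4] = {0.1f, 0.7f, 0.2f, 0.0f,  3.0f, -1.0f, 9.0f, 2.0f};
    int32_t idx[2];
    argmax_rows(2, 4, x, idx);
    printf("%d %d\n", idx[0], idx[1]); // 1 2
    return 0;
}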
-// ggml_compute_forward_count_equal
-
-static void ggml_compute_forward_count_equal_i32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    GGML_ASSERT(src0->type == GGML_TYPE_I32);
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    GGML_ASSERT(ggml_is_scalar(dst));
-    GGML_ASSERT(dst->type == GGML_TYPE_I64);
-
-    const int64_t nr = ggml_nrows(src0);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    int64_t * sums = (int64_t *) params->wdata;
-    int64_t sum_thread = 0;
-
-    // rows per thread
-    const int64_t dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int64_t ir0 = dr*ith;
-    const int64_t ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t ir = ir0; ir < ir1; ++ir) {
-        const int64_t i03 =  ir                        / (ne02*ne01);
-        const int64_t i02 = (ir - i03*ne03)            /       ne01;
-        const int64_t i01 =  ir - i03*ne03 - i02*ne02;
-
-        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
-        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
-
-        for (int64_t i00 = 0; i00 < ne00; ++i00) {
-            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
-            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
-
-            sum_thread += val0 == val1;
-        }
-    }
-    if (ith != 0) {
-        sums[ith] = sum_thread;
-    }
-    ggml_barrier(params->threadpool);
-
-    if (ith != 0) {
-        return;
-    }
-
-    for (int ith_other = 1; ith_other < nth; ++ith_other) {
-        sum_thread += sums[ith_other];
-    }
-    *((int64_t *) dst->data) = sum_thread;
-}
-
-static void ggml_compute_forward_count_equal(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_I32:
-            {
-                ggml_compute_forward_count_equal_i32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_repeat
-
-static void ggml_compute_forward_repeat_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_can_repeat(src0, dst));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int nr0 = (int)(ne0/ne00);
-    const int nr1 = (int)(ne1/ne01);
-    const int nr2 = (int)(ne2/ne02);
-    const int nr3 = (int)(ne3/ne03);
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // TODO: maybe this is not optimal?
-    for                         (int i3 = 0; i3 < nr3;  i3++) {
-        for                     (int k3 = 0; k3 < ne03; k3++) {
-            for                 (int i2 = 0; i2 < nr2;  i2++) {
-                for             (int k2 = 0; k2 < ne02; k2++) {
-                    for         (int i1 = 0; i1 < nr1;  i1++) {
-                        for     (int k1 = 0; k1 < ne01; k1++) {
-                            for (int i0 = 0; i0 < nr0;  i0++) {
-                                ggml_vec_cpy_f32(ne00,
-                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),
-                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_repeat_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_can_repeat(src0, dst));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int nr0 = (int)(ne0/ne00);
-    const int nr1 = (int)(ne1/ne01);
-    const int nr2 = (int)(ne2/ne02);
-    const int nr3 = (int)(ne3/ne03);
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // TODO: maybe this is not optimal?
-    for                         (int i3 = 0; i3 < nr3;  i3++) {
-        for                     (int k3 = 0; k3 < ne03; k3++) {
-            for                 (int i2 = 0; i2 < nr2;  i2++) {
-                for             (int k2 = 0; k2 < ne02; k2++) {
-                    for         (int i1 = 0; i1 < nr1;  i1++) {
-                        for     (int k1 = 0; k1 < ne01; k1++) {
-                            for (int i0 = 0; i0 < nr0;  i0++) {
-                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
-                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
-                                // ggml_vec_cpy_f16(ne00, y, x)
-                                for (int i = 0; i < ne00; ++i) {
-                                    y[i]  = x[i];
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_repeat(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-        case GGML_TYPE_I16:
-            {
-                ggml_compute_forward_repeat_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-        case GGML_TYPE_I32:
-            {
-                ggml_compute_forward_repeat_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
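repeat tiles src0 nr0..nr3 times along each dimension, where the repeat counts are exact because ggml_can_repeat enforces divisibility. A small 2D sketch of the same tiling on plain row-major arrays (repeat_2d is illustrative only):

#include <stdio.h>

/* Tile a (sr x sc) matrix r1 times vertically and r0 times
 * horizontally into a (r1*sr x r0*sc) destination. */
static void repeat_2d(int sr, int sc, const float *src,
                      int r1, int r0, float *dst) {
    const int dc = r0*sc; // destination row length
    for (int i1 = 0; i1 < r1; i1++)
    for (int k1 = 0; k1 < sr; k1++)
    for (int i0 = 0; i0 < r0; i0++)
    for (int k0 = 0; k0 < sc; k0++) {
        dst[(i1*sr + k1)*dc + i0*sc + k0] = src[k1*sc + k0];
    }
}

int main(void) {
    const float src[2*2] = {1, 2, 3, 4};
    float dst[4*4];
    repeat_2d(2, 2, src, 2, 2, dst);               // 2x2 tiled into 4x4
    printf("%g %g\n", dst[0*4 + 2], dst[3*4 + 3]); // 1 4
    return 0;
}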
-// ggml_compute_forward_repeat_back
-
-static void ggml_compute_forward_repeat_back_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_can_repeat(dst, src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int nr0 = (int)(ne00/ne0);
-    const int nr1 = (int)(ne01/ne1);
-    const int nr2 = (int)(ne02/ne2);
-    const int nr3 = (int)(ne03/ne3);
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    if (ggml_is_contiguous(dst)) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
-    } else {
-        for         (int k3 = 0; k3 < ne3; k3++) {
-            for     (int k2 = 0; k2 < ne2; k2++) {
-                for (int k1 = 0; k1 < ne1; k1++) {
-                    ggml_vec_set_f32(ne0,
-                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
-                        0);
-                }
-            }
-        }
-    }
-
-    // TODO: maybe this is not optimal?
-    for                         (int i3 = 0; i3 < nr3; i3++) {
-        for                     (int k3 = 0; k3 < ne3; k3++) {
-            for                 (int i2 = 0; i2 < nr2; i2++) {
-                for             (int k2 = 0; k2 < ne2; k2++) {
-                    for         (int i1 = 0; i1 < nr1; i1++) {
-                        for     (int k1 = 0; k1 < ne1; k1++) {
-                            for (int i0 = 0; i0 < nr0; i0++) {
-                                ggml_vec_acc_f32(ne0,
-                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
-                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_repeat_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_repeat_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_concat
-
-static void ggml_compute_forward_concat_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int32_t dim = ggml_get_op_params_i32(dst, 0);
-
-    GGML_ASSERT(dim >= 0 && dim < 4);
-
-    int64_t o[4] = {0, 0, 0, 0};
-    o[dim] = src0->ne[dim];
-
-    const float * x;
-
-    // TODO: smarter multi-threading
-    for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = ith; i2 < ne2; i2 += nth) {
-            for (int i1 = 0; i1 < ne1; i1++) {
-                for (int i0 = 0; i0 < ne0; i0++) {
-                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
-                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
-                    } else {
-                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
-                    }
-
-                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
-
-                    *y = *x;
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_concat(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-        case GGML_TYPE_I32:
-            {
-                ggml_compute_forward_concat_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
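concat reads from src0 while the index is still inside src0's extent along the concat dimension, and otherwise from src1 shifted back by src0->ne[dim] — that is what the o[dim] offset implements. A 1D sketch of the same selection rule (concat_1d is a hypothetical helper):

#include <stdio.h>

/* Concatenate two float vectors using the same "offset into the
 * second source" rule as the removed kernel. */
static void concat_1d(const float *a, int na, const float *b, int nb, float *dst) {
    for (int i = 0; i < na + nb; i++) {
        dst[i] = (i < na) ? a[i] : b[i - na];
    }
}

int main(void) {
    const float a[3] = {1, 2, 3};
    const float b[2] = {4, 5};
    float out[5];
    concat_1d(a, 3, b, 2, out);
    printf("%g %g\n", out[2], out[3]); // 3 4
    return 0;
}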
-// ggml_compute_forward_abs
-
-static void ggml_compute_forward_abs_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_abs_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_abs(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_abs_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sgn
-
-static void ggml_compute_forward_sgn_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sgn_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_sgn(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sgn_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_neg
-
-static void ggml_compute_forward_neg_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_neg_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_neg(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_neg_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_step
-
-static void ggml_compute_forward_step_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_step_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_step(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_step_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_tanh
-
-static void ggml_compute_forward_tanh_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_tanh_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_tanh(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_tanh_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_elu
-
-static void ggml_compute_forward_elu_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_elu_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_elu(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_elu_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_relu
-
-static void ggml_compute_forward_relu_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_relu_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_relu(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_relu_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sigmoid
-
-static void ggml_compute_forward_sigmoid_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sigmoid_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_sigmoid(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sigmoid_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_gelu
-
-static void ggml_compute_forward_gelu_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_gelu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
-
-#ifndef NDEBUG
-        for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            UNUSED(x);
-            assert(!isnan(x));
-            assert(!isinf(x));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_gelu(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_gelu_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
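The multithreaded kernels here (gelu, gelu_quick, silu, norm, and later mul_mat) all split work with the same ceiling-division recipe: dr = (nr + nth - 1)/nth rows per thread, with thread ith handling the half-open range [dr*ith, min(dr*ith + dr, nr)). A quick standalone check of that arithmetic (hypothetical row/thread counts):

#include <stdio.h>

int main(void) {
    const int nr  = 10; // total rows (hypothetical)
    const int nth = 4;  // number of threads (hypothetical)

    const int dr = (nr + nth - 1)/nth; // ceil(nr/nth) = 3
    for (int ith = 0; ith < nth; ith++) {
        const int ir0 = dr*ith;
        const int ir1 = (ir0 + dr < nr) ? ir0 + dr : nr;
        // the last thread gets the shorter tail range: [9, 10) here
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}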
-// ggml_compute_forward_gelu_quick
-
-static void ggml_compute_forward_gelu_quick_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_gelu_quick_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
-
-#ifndef NDEBUG
-        for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            UNUSED(x);
-            assert(!isnan(x));
-            assert(!isinf(x));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_gelu_quick(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_gelu_quick_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_silu
-
-static void ggml_compute_forward_silu_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_silu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
-
-#ifndef NDEBUG
-        for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
-            UNUSED(x);
-            assert(!isnan(x));
-            assert(!isinf(x));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_silu(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_silu_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
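SiLU (swish) is x·sigmoid(x); the removed kernel delegates the per-element math to ggml_vec_silu_f32. A scalar reference version for comparison (not the vectorized ggml routine):

#include <math.h>
#include <stdio.h>

// silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
static float silu_ref(float x) {
    return x / (1.0f + expf(-x));
}

int main(void) {
    printf("%f %f %f\n", silu_ref(-2.0f), silu_ref(0.0f), silu_ref(2.0f));
    // approx -0.238406, 0, 1.761594
    return 0;
}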
-// ggml_compute_forward_leaky_relu
-
-static void ggml_compute_forward_leaky_relu_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    float negative_slope;
-    memcpy(&negative_slope, dst->op_params, sizeof(float));
-
-    assert(dst->nb[0]  == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_leaky_relu_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
-    }
-}
-
-static void ggml_compute_forward_leaky_relu(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_leaky_relu_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_silu_back
-
-static void ggml_compute_forward_silu_back_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * grad = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    assert(ggml_is_contiguous_1(grad));
-    assert(ggml_is_contiguous_1(src1));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src1, dst));
-    assert(ggml_are_same_shape(src1, grad));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src1->ne[0];
-    const int nr = ggml_nrows(src1);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_silu_backward_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src1->data + i1*(src1->nb[1])),
-                (float *) ((char *) grad->data + i1*(grad->nb[1])));
-
-#ifndef NDEBUG
-        for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            UNUSED(x);
-            assert(!isnan(x));
-            assert(!isinf(x));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_silu_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_silu_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
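silu_back multiplies the incoming gradient by dsilu/dx. Differentiating x·σ(x) gives σ(x)·(1 + x·(1 − σ(x))); a scalar sketch with a finite-difference check (illustrative helpers, not ggml_vec_silu_backward_f32):

#include <math.h>
#include <stdio.h>

static float sigmoid(float x) { return 1.0f/(1.0f + expf(-x)); }
static float silu(float x)    { return x*sigmoid(x); }

// d/dx [x*sigmoid(x)] = sigmoid(x)*(1 + x*(1 - sigmoid(x)))
static float dsilu(float x) {
    const float s = sigmoid(x);
    return s*(1.0f + x*(1.0f - s));
}

int main(void) {
    const float x = 0.7f, h = 1e-3f;
    const float numeric = (silu(x + h) - silu(x - h)) / (2.0f*h);
    printf("analytic %f vs numeric %f\n", dsilu(x), numeric);
    return 0;
}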
-
-static void ggml_compute_forward_hardswish_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_hardswish_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-static void ggml_compute_forward_hardswish(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_hardswish_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-static void ggml_compute_forward_hardsigmoid_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_hardsigmoid_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_hardsigmoid(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_hardsigmoid_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-static void ggml_compute_forward_exp_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_exp_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_exp(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_exp_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-
-// ggml_compute_forward_norm
-
-static void ggml_compute_forward_norm_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    GGML_ASSERT(eps >= 0.0f);
-
-    // TODO: optimize
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
-                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                ggml_float sum = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum += (ggml_float)x[i00];
-                }
-
-                float mean = sum/ne00;
-
-                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
-                ggml_float sum2 = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    float v = x[i00] - mean;
-                    y[i00] = v;
-                    sum2 += (ggml_float)(v*v);
-                }
-
-                float variance = sum2/ne00;
-                const float scale = 1.0f/sqrtf(variance + eps);
-
-                ggml_vec_scale_f32(ne00, y, scale);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_norm(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_norm_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_group_rms_norm
-
-static void ggml_compute_forward_rms_norm_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    GGML_ASSERT(eps >= 0.0f);
-
-    // TODO: optimize
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
-                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                ggml_float sum = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum += (ggml_float)(x[i00] * x[i00]);
-                }
-
-                const float mean = sum/ne00;
-
-                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
-                memcpy(y, x, ne00 * sizeof(float));
-                // for (int i00 = 0; i00 < ne00; i00++) {
-                //     y[i00] = x[i00];
-                // }
-
-                const float scale = 1.0f/sqrtf(mean + eps);
-
-                ggml_vec_scale_f32(ne00, y, scale);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_rms_norm(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rms_norm_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
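Unlike norm above, rms_norm does not subtract the mean: each row is scaled by 1/sqrt(mean(x²) + eps). A row-level sketch of that scaling (rms_norm_row is an illustrative name, not ggml API):

#include <math.h>
#include <stdio.h>

/* y = x / sqrt(mean(x^2) + eps) — the per-row scaling applied by the
 * removed rms_norm kernel, accumulating the sum of squares in double. */
static void rms_norm_row(int n, const float *x, float *y, float eps) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) x[i] * x[i];
    }
    const float scale = 1.0f / sqrtf((float)(sum/n) + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * scale;
    }
}

int main(void) {
    const float x[4] = {1, 2, 3, 4};
    float y[4];
    rms_norm_row(4, x, y, 1e-6f);
    printf("%f %f\n", y[0], y[3]); // mean(x^2) = 7.5, scale ~= 0.3651
    return 0;
}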
-static void ggml_compute_forward_rms_norm_back_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
-    const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(src1->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    // TODO: optimize
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
-                // src1 is same shape as src0 => same indices
-                const int64_t i11 = i01;
-                const int64_t i12 = i02;
-                const int64_t i13 = i03;
-
-                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
-
-                ggml_float sum_xx  = 0.0;
-                ggml_float sum_xdz = 0.0;
-
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum_xx  += (ggml_float)(x[i00] * x[i00]);
-                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);
-                }
-
-                //const float mean     = (float)(sum_xx)/ne00;
-                const float mean_eps = (float)(sum_xx)/ne00 + eps;
-                const float sum_eps  = (float)(sum_xx) + eps*ne00;
-                //const float mean_xdz = (float)(sum_xdz)/ne00;
-                // we could cache rms from forward pass to improve performance.
-                // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
-                //const float rms      = sqrtf(mean_eps);
-                const float rrms     = 1.0f / sqrtf(mean_eps);
-                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
-
-                {
-                    // z = rms_norm(x)
-                    //
-                    // rms_norm(src1) =
-                    //     scale(
-                    //         src1,
-                    //         div(
-                    //             1,
-                    //             sqrt(
-                    //                 add(
-                    //                     scale(
-                    //                         sum(
-                    //                             sqr(
-                    //                                 src1)),
-                    //                         (1.0/N)),
-                    //                     eps))));
-
-                    // postorder:
-                    // ## op    args         grad
-                    // 00 param src1         grad[#00]
-                    // 01 const 1
-                    // 02 sqr   (#00)        grad[#02]
-                    // 03 sum   (#02)        grad[#03]
-                    // 04 const 1/N
-                    // 05 scale (#03, #04)   grad[#05]
-                    // 06 const eps
-                    // 07 add   (#05, #06)   grad[#07]
-                    // 08 sqrt  (#07)        grad[#08]
-                    // 09 div   (#01,#08)    grad[#09]
-                    // 10 scale (#00,#09)    grad[#10]
-                    //
-                    // backward pass, given grad[#10]
-                    // #10: scale
-                    // grad[#00] += scale(grad[#10],#09)
-                    // grad[#09] += sum(mul(grad[#10],#00))
-                    // #09: div
-                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
-                    // #08: sqrt
-                    // grad[#07] += mul(grad[#08], div(0.5, #08))
-                    // #07: add
-                    // grad[#05] += grad[#07]
-                    // #05: scale
-                    // grad[#03] += scale(grad[#05],#04)
-                    // #03: sum
-                    // grad[#02] += repeat(grad[#03], #02)
-                    // #02:
-                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
-                    //
-                    // substitute and simplify:
-                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
-                    // grad[#02] = repeat(grad[#03], #02)
-                    // grad[#02] = repeat(scale(grad[#05],#04), #02)
-                    // grad[#02] = repeat(scale(grad[#07],#04), #02)
-                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
-                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
-                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
-                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
-                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
-                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
-                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
-                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
-                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
-                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
-                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
-                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
-                    // a = b*c + d*e
-                    // a = b*c*f/f + d*e*f/f
-                    // a = (b*c*f + d*e*f)*(1/f)
-                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
-                    // a = (b + d*e/c)*c
-                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
-                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
-                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
-                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
-                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
-                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
-                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
-                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
-                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
-                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
-                }
-                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
-                // post-order:
-                // dx := x
-                // dx := scale(dx,-mean_xdz/mean_eps)
-                // dx := add(dx, dz)
-                // dx := scale(dx, rrms)
-                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
-                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
-                ggml_vec_cpy_f32  (ne00, dx, x);
-                // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
-                ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
-                ggml_vec_acc_f32  (ne00, dx, dz);
-                ggml_vec_scale_f32(ne00, dx, rrms);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_rms_norm_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rms_norm_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
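The long comment in the removed rms_norm_back ends at dx = scale(dz + scale(x, -mean_xdz/mean_eps), rrms), which the code applies with -sum_xdz/sum_eps. A small standalone check of that formula against a finite difference of the loss sum(dz·y), under the same definitions; the helper names are hypothetical, not ggml routines:

#include <math.h>
#include <stdio.h>

// Forward: y = x / sqrt(mean(x^2) + eps)
static void rms_norm(int n, const float *x, float *y, float eps) {
    double s = 0.0;
    for (int i = 0; i < n; i++) s += (double) x[i]*x[i];
    const float rrms = 1.0f/sqrtf((float)(s/n) + eps);
    for (int i = 0; i < n; i++) y[i] = x[i]*rrms;
}

// Backward per the derivation above: dx = (dz + x*(-sum_xdz/(n*mean_eps)))*rrms
static void rms_norm_back(int n, const float *x, const float *dz, float *dx, float eps) {
    double sum_xx = 0.0, sum_xdz = 0.0;
    for (int i = 0; i < n; i++) {
        sum_xx  += (double) x[i]*x[i];
        sum_xdz += (double) x[i]*dz[i];
    }
    const float mean_eps = (float)(sum_xx/n) + eps;
    const float rrms     = 1.0f/sqrtf(mean_eps);
    for (int i = 0; i < n; i++) {
        dx[i] = (dz[i] + x[i]*(float)(-sum_xdz/(n*mean_eps)))*rrms;
    }
}

int main(void) {
    // Finite-difference check of d(sum(dz*y))/dx[0] against the analytic dx[0].
    const float eps = 1e-6f, h = 1e-3f;
    const float x[4]  = {0.5f, -1.0f, 2.0f, 0.25f};
    const float dz[4] = {1.0f, 0.5f, -0.25f, 2.0f};
    float dx[4];
    rms_norm_back(4, x, dz, dx, eps);

    float xp[4], xm[4], yp[4], ym[4];
    for (int i = 0; i < 4; i++) { xp[i] = x[i]; xm[i] = x[i]; }
    xp[0] += h; xm[0] -= h;
    rms_norm(4, xp, yp, eps);
    rms_norm(4, xm, ym, eps);
    float lp = 0.0f, lm = 0.0f;
    for (int i = 0; i < 4; i++) { lp += dz[i]*yp[i]; lm += dz[i]*ym[i]; }

    printf("analytic %f vs numeric %f\n", dx[0], (lp - lm)/(2.0f*h));
    return 0;
}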
-// ggml_compute_forward_group_norm
-
-static void ggml_compute_forward_group_norm_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    // TODO: optimize
-
-    float eps;
-    memcpy(&eps, dst->op_params + 1, sizeof(float));
-
-    int n_channels = src0->ne[2];
-    int n_groups = dst->op_params[0];
-    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
-    for (int i = ith; i < n_groups; i += nth) {
-        int start = i * n_channels_per_group;
-        int end = start + n_channels_per_group;
-        if (end > n_channels) {
-            end = n_channels;
-        }
-        int step = end - start;
-
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            ggml_float sum = 0.0;
-            for (int64_t i02 = start; i02 < end; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
-
-                    ggml_float sumr = 0.0;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        sumr += (ggml_float)x[i00];
-                    }
-                    sum += sumr;
-                }
-            }
-            const float mean = sum / (ne00 * ne01 * step);
-
-            ggml_float sum2 = 0.0;
-            for (int64_t i02 = start; i02 < end; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
-
-                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
-
-                    ggml_float sumr = 0.0;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        float v = x[i00] - mean;
-                        y[i00] = v;
-                        sumr += (ggml_float)(v * v);
-                    }
-                    sum2 += sumr;
-                }
-            }
-            const float variance = sum2 / (ne00 * ne01 * step);
-            const float scale = 1.0f / sqrtf(variance + eps);
-
-            for (int64_t i02 = start; i02 < end; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
-                    ggml_vec_scale_f32(ne00, y, scale);
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_group_norm(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_group_norm_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
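group_norm splits the channel dimension (ne[2]) into n_groups chunks of n_channels_per_group and normalizes each group with its own mean and variance (eps comes from dst->op_params + 1). A minimal sketch of the grouping arithmetic and normalization over a flat [channels][width] array; the function name and layout are assumptions for illustration:

#include <math.h>
#include <stdio.h>

/* Normalize groups of channels in place: each group of `cpg`
 * consecutive channels (width `w` each) shares one mean/variance. */
static void group_norm(int channels, int w, int n_groups, float eps, float *x) {
    const int cpg = (channels + n_groups - 1)/n_groups; // channels per group
    for (int g = 0; g < n_groups; g++) {
        const int c0 = g*cpg;
        const int c1 = (c0 + cpg < channels) ? c0 + cpg : channels;
        const int n  = (c1 - c0)*w;

        double sum = 0.0;
        for (int i = c0*w; i < c1*w; i++) sum += x[i];
        const float mean = (float)(sum/n);

        double sum2 = 0.0;
        for (int i = c0*w; i < c1*w; i++) {
            const float v = x[i] - mean;
            x[i] = v;
            sum2 += (double) v*v;
        }
        const float scale = 1.0f/sqrtf((float)(sum2/n) + eps);
        for (int i = c0*w; i < c1*w; i++) x[i] *= scale;
    }
}

int main(void) {
    float x[4*2] = {1, 2, 3, 4, 10, 20, 30, 40}; // 4 channels, width 2, 2 groups
    group_norm(4, 2, 2, 1e-6f, x);
    printf("%f %f\n", x[0], x[7]);
    return 0;
}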
 // ggml_compute_forward_mul_mat
 
 static void ggml_compute_forward_mul_mat_one_chunk(
@@ -7979,4982 +1719,6 @@ static void ggml_compute_forward_mul_mat_id(
     }
 }
 
-// ggml_compute_forward_out_prod
-
-static void ggml_compute_forward_out_prod_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_ASSERT(ne0 == ne00);
-    GGML_ASSERT(ne1 == ne10);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
-    GGML_ASSERT(ne2 % ne02 == 0);
-    GGML_ASSERT(ne3 % ne03 == 0);
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    // GGML_ASSERT(nb0 <= nb1);
-    // GGML_ASSERT(nb1 <= nb2);
-    // GGML_ASSERT(nb2 <= nb3);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-    if (ith == 0) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
-    }
-    ggml_barrier(params->threadpool);
-
-    // dst[:,:,:,:] = 0
-    // for i2,i3:
-    //   for i1:
-    //     for i01:
-    //       for i0:
-    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
-
-    // parallelize by last three dimensions
-
-    // total rows in dst
-    const int64_t nr = ne1*ne2*ne3;
-
-    // rows per thread
-    const int64_t dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int64_t ir0 = dr*ith;
-    const int64_t ir1 = MIN(ir0 + dr, nr);
-
-    // block-tiling attempt
-    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
-    const int64_t blck_1 = 16;
-
-    // dps == dst per src0, used for group query attention
-    const int64_t dps2 = ne2 / ne02;
-    const int64_t dps3 = ne3 / ne03;
-
-    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
-        const int64_t bir1 = MIN(bir + blck_1, ir1);
-        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
-            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
-            for (int64_t ir = bir; ir < bir1; ++ir) {
-                // dst indices
-                const int64_t i3 = ir/(ne2*ne1);
-                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
-                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-                const int64_t i02 = i2 / dps2;
-                const int64_t i03 = i3 / dps3;
-
-                //const int64_t i10 = i1;
-                const int64_t i12 = i2;
-                const int64_t i13 = i3;
-
-#if GGML_VEC_MAD_UNROLL > 2
-                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
-                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
-                    const int64_t i11 = i01;
-
-                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
-
-                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
-                }
-                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
-                    const int64_t i11 = i01;
-
-                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
-
-                    ggml_vec_mad_f32(ne0, d, s0, *s1);
-                }
-#else
-                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
-                    const int64_t i11 = i01;
-
-                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
-
-                    ggml_vec_mad_f32(ne0, d, s0, *s1);
-                }
-#endif
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_out_prod_q_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const enum ggml_type type = src0->type;
-    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
-    GGML_ASSERT(ne3  == ne13);
-
-    // we don't support permuted src0 dim0
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-
-    // dst dim0 cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    // GGML_ASSERT(nb0 <= nb1);
-    // GGML_ASSERT(nb1 <= nb2);
-    // GGML_ASSERT(nb2 <= nb3);
-
-    GGML_ASSERT(ne0 == ne00);
-    GGML_ASSERT(ne1 == ne10);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-    if (ith == 0) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
-    }
-    ggml_barrier(params->threadpool);
-
-    // parallelize by last three dimensions
-
-    // total rows in dst
-    const int64_t nr = ne1*ne2*ne3;
-
-    // rows per thread
-    const int64_t dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int64_t ir0 = dr*ith;
-    const int64_t ir1 = MIN(ir0 + dr, nr);
-
-    // dst[:,:,:,:] = 0
-    // for i2,i3:
-    //   for i1:
-    //     for i01:
-    //       for i0:
-    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
-
-    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
-
-    for (int64_t ir = ir0; ir < ir1; ++ir) {
-        // dst indices
-        const int64_t i3 = ir/(ne2*ne1);
-        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
-        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        const int64_t i02 = i2;
-        const int64_t i03 = i3;
-
-        //const int64_t i10 = i1;
-        const int64_t i12 = i2;
-        const int64_t i13 = i3;
-
-        for (int64_t i01 = 0; i01 < ne01; ++i01) {
-            const int64_t i11 = i01;
-
-            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
-
-            dequantize_row_q(s0, wdata, ne0);
-            ggml_vec_mad_f32(ne0, d, wdata, *s1);
-        }
-    }
-}
-
-static void ggml_compute_forward_out_prod(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-            {
-                ggml_compute_forward_out_prod_q_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-            {
-                GGML_ABORT("fatal error"); // todo
-                // ggml_compute_forward_out_prod_f16_f32(params, dst);
-            }
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_out_prod_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_scale
-
-static void ggml_compute_forward_scale_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    // scale factor
-    float v;
-    memcpy(&v, dst->op_params, sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const size_t nb01 = src0->nb[1];
-
-    const size_t nb1 = dst->nb[1];
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        if (dst->data != src0->data) {
-            // src0 is same shape as dst => same indices
-            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
-        }
-        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
-    }
-}
-
-static void ggml_compute_forward_scale(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_scale_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_set
-
-static void ggml_compute_forward_set_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-
-    // view src0 and dst with these strides and data offset inbytes during set
-    // nb0 is implicitly element_size because src0 and dst are contiguous
-    size_t nb1     = ((int32_t *) dst->op_params)[0];
-    size_t nb2     = ((int32_t *) dst->op_params)[1];
-    size_t nb3     = ((int32_t *) dst->op_params)[2];
-    size_t offset  = ((int32_t *) dst->op_params)[3];
-    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
-
-    if (!inplace) {
-        if (params->ith == 0) {
-            // memcpy needs to be synchronized across threads to avoid race conditions.
-            // => do it in INIT phase
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
-        }
-        ggml_barrier(params->threadpool);
-    }
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(src1);
-    const int nc = src1->ne[0];
-
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
-
-    // src0 and dst as viewed during set
-    const size_t nb0 = ggml_element_size(src0);
-
-    const int im0 = (ne10 == 0 ? 0 : ne10-1);
-    const int im1 = (ne11 == 0 ? 0 : ne11-1);
-    const int im2 = (ne12 == 0 ? 0 : ne12-1);
-    const int im3 = (ne13 == 0 ? 0 : ne13-1);
-
-    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
-
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are viewed with shape of src1 and offset
-        // => same indices
-        const int i3 = ir/(ne12*ne11);
-        const int i2 = (ir - i3*ne12*ne11)/ne11;
-        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
-
-        ggml_vec_cpy_f32(nc,
-                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
-                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
-    }
-}
-
-static void ggml_compute_forward_set_i32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-
-    // view src0 and dst with these strides and data offset inbytes during set
-    // nb0 is implicitly element_size because src0 and dst are contiguous
-    size_t nb1     = ((int32_t *) dst->op_params)[0];
-    size_t nb2     = ((int32_t *) dst->op_params)[1];
-    size_t nb3     = ((int32_t *) dst->op_params)[2];
-    size_t offset  = ((int32_t *) dst->op_params)[3];
-    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
-
-    if (!inplace) {
-        if (params->ith == 0) {
-            // memcpy needs to be synchronized across threads to avoid race conditions.
-            // => do it in INIT phase
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
-        }
-        ggml_barrier(params->threadpool);
-    }
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(src1);
-    const int nc = src1->ne[0];
-
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
-
-    // src0 and dst as viewed during set
-    const size_t nb0 = ggml_element_size(src0);
-
-    const int im0 = (ne10 == 0 ? 0 : ne10-1);
-    const int im1 = (ne11 == 0 ? 0 : ne11-1);
-    const int im2 = (ne12 == 0 ? 0 : ne12-1);
-    const int im3 = (ne13 == 0 ? 0 : ne13-1);
-
-    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
-
-    GGML_ASSERT(nb10 == sizeof(int32_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are viewed with shape of src1 and offset
-        // => same indices
-        const int i3 = ir/(ne12*ne11);
-        const int i2 = (ir - i3*ne12*ne11)/ne11;
-        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
-
-        ggml_vec_cpy_i32(nc,
-                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
-                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
-    }
-}
-
-static void ggml_compute_forward_set(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_set_f32(params, dst);
-            } break;
-        case GGML_TYPE_I32:
-            {
-                ggml_compute_forward_set_i32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_cpy
-
-static void ggml_compute_forward_cpy(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    ggml_compute_forward_dup(params, dst);
-}
-
-// ggml_compute_forward_cont
-
-static void ggml_compute_forward_cont(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    ggml_compute_forward_dup(params, dst);
-}
-
-// ggml_compute_forward_reshape
-
-static void ggml_compute_forward_reshape(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    // NOP
-    UNUSED(params);
-    UNUSED(dst);
-}
-
-// ggml_compute_forward_view
-
-static void ggml_compute_forward_view(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * dst) {
-    // NOP
-    UNUSED(params);
-    UNUSED(dst);
-}
-
-// ggml_compute_forward_permute
-
-static void ggml_compute_forward_permute(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * dst) {
-    // NOP
-    UNUSED(params);
-    UNUSED(dst);
-}
-
-// ggml_compute_forward_transpose
-
-static void ggml_compute_forward_transpose(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * dst) {
-    // NOP
-    UNUSED(params);
-    UNUSED(dst);
-}
-
-// ggml_compute_forward_get_rows
-
-static void ggml_compute_forward_get_rows_q(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1);
-
-    const enum ggml_type type = src0->type;
-    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-
-    assert(ne0  == nc);
-    assert(ne02 == ne11);
-    assert(nb00 == ggml_type_size(type));
-    assert(ggml_nrows(dst) == nr);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
-        GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
-        dequantize_row_q(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
-    }
-}
-
-static void ggml_compute_forward_get_rows_f16(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1);
-
-    assert(ne0  == nc);
-    assert(ne02 == ne11);
-    assert(nb00 == sizeof(ggml_fp16_t));
-    assert(ggml_nrows(dst) == nr);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
-        GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
-        ggml_fp16_to_fp32_row(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
-    }
-}
-
-static void ggml_compute_forward_get_rows_bf16(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1);
-
-    assert(ne0  == nc);
-    assert(ne02 == ne11);
-    assert(nb00 == sizeof(ggml_bf16_t));
-    assert(ggml_nrows(dst) == nr);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
-        GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
-        ggml_bf16_to_fp32_row(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
-    }
-}
-
-static void ggml_compute_forward_get_rows_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1);
-
-    assert(ne0  == nc);
-    assert(ne02 == ne11);
-    assert(nb00 == sizeof(float));
-    assert(ggml_nrows(dst) == nr);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
-        GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
-        ggml_vec_cpy_f32(nc,
-                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
-                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
-    }
-}
-
-static void ggml_compute_forward_get_rows(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-            {
-                ggml_compute_forward_get_rows_q(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_get_rows_f16(params, dst);
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                ggml_compute_forward_get_rows_bf16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-        case GGML_TYPE_I32:
-            {
-                ggml_compute_forward_get_rows_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-
-    //static bool first = true;
-    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    //if (first) {
-    //    first = false;
-    //} else {
-    //    for (int k = 0; k < dst->ne[1]; ++k) {
-    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
-    //            for (int i = 0; i < 16; ++i) {
-    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-    //            }
-    //            printf("\n");
-    //        }
-    //        printf("\n");
-    //    }
-    //    printf("\n");
-    //    exit(0);
-    //}
-}
-
-// ggml_compute_forward_get_rows_back
-
-static void ggml_compute_forward_get_rows_back_f32_f16(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_is_contiguous(dst));
-
-    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
-
-    memset(dst->data, 0, ggml_nbytes(dst));
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
-
-    GGML_ASSERT( dst->ne[0] == nc);
-    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
-
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
-
-        for (int j = 0; j < nc; ++j) {
-            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
-            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
-        }
-    }
-}
-
-static void ggml_compute_forward_get_rows_back_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_is_contiguous(dst));
-
-    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
-
-    memset(dst->data, 0, ggml_nbytes(dst));
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
-
-    GGML_ASSERT( dst->ne[0] == nc);
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
-
-        ggml_vec_add_f32(nc,
-                (float *) ((char *)  dst->data + r*dst->nb[1]),
-                (float *) ((char *)  dst->data + r*dst->nb[1]),
-                (float *) ((char *) src0->data + i*src0->nb[1]));
-    }
-}
-
-static void ggml_compute_forward_get_rows_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_get_rows_back_f32_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_get_rows_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-
-    //static bool first = true;
-    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    //if (first) {
-    //    first = false;
-    //} else {
-    //    for (int k = 0; k < dst->ne[1]; ++k) {
-    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
-    //            for (int i = 0; i < 16; ++i) {
-    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-    //            }
-    //            printf("\n");
-    //        }
-    //        printf("\n");
-    //    }
-    //    printf("\n");
-    //    exit(0);
-    //}
-}
-
-// ggml_compute_forward_diag
-
-static void ggml_compute_forward_diag_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    // TODO: handle transposed/permuted matrices
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(ne00 == ne0);
-    GGML_ASSERT(ne00 == ne1);
-    GGML_ASSERT(ne01 == 1);
-    GGML_ASSERT(ne02 == ne2);
-    GGML_ASSERT(ne03 == ne3);
-
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb0  == sizeof(float));
-
-    for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = 0; i2 < ne2; i2++) {
-            for (int i1 = 0; i1 < ne1; i1++) {
-                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);
-                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
-                for (int i0 = 0; i0 < i1; i0++) {
-                    d[i0] = 0;
-                }
-                d[i1] = s[i1];
-                for (int i0 = i1+1; i0 < ne0; i0++) {
-                    d[i0] = 0;
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_diag(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_diag_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_diag_mask_inf
-
-static void ggml_compute_forward_diag_mask_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const float value) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int  n_past  = ((int32_t *) dst->op_params)[0];
-    const bool inplace = src0->data == dst->data;
-
-    GGML_ASSERT(n_past >= 0);
-
-    if (!inplace) {
-        if (ith == 0) {
-            // memcpy needs to be synchronized across threads to avoid race conditions.
-            // => do it in INIT phase
-            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
-        }
-        ggml_barrier(params->threadpool);
-    }
-
-    // TODO: handle transposed/permuted matrices
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-    const int nr = src0->ne[1];
-    const int nz = n/nr;
-
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    for (int k = 0; k < nz; k++) {
-        for (int j = ith; j < nr; j += nth) {
-            for (int i = n_past; i < nc; i++) {
-                if (i > n_past + j) {
-                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_diag_mask_inf(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-static void ggml_compute_forward_diag_mask_zero(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_diag_mask_f32(params, dst, 0);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_soft_max
-
-static void ggml_compute_forward_soft_max_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    assert(ggml_is_contiguous(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
-
-    // TODO: handle transposed/permuted matrices
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
-
-    // TODO: is this supposed to be ceil instead of floor?
-    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
-    const uint32_t n_head      = ne02;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
-
-    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        // ALiBi
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
-
-        // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_f32  (nc, wp, sp);
-        ggml_vec_scale_f32(nc, wp, scale);
-        if (mp_f32) {
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*mp_f32[i];
-                }
-            }
-        }
-
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
-#endif
-
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, wp);
-
-        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
-
-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, dp, sum);
-
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_soft_max(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_soft_max_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-
-// ggml_compute_forward_soft_max_ext_back
-
-static void ggml_compute_forward_soft_max_ext_back_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_are_same_shape(src1, dst));
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
-
-    GGML_ASSERT(max_bias == 0.0f);
-
-    // TODO: handle transposed/permuted matrices
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
-        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
-
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(dy[i]));
-            assert(!isnan(y[i]));
-        }
-#endif
-        // Jii = yi - yi*yi
-        // Jij = -yi*yj
-        // J = diag(y)-y.T*y
-        // dx = J * dy
-        // dxk = sum_i(Jki * dyi)
-        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
-        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
-        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
-        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
-        // dxk = -yk * dot(y, dy) + yk*dyk
-        // dxk = yk * (- dot(y, dy) + dyk)
-        // dxk = yk * (dyk - dot(y, dy))
-        //
-        // post-order:
-        // dot_y_dy := dot(y, dy)
-        // dx := dy
-        // dx := dx - dot_y_dy
-        // dx := dx * y
-
-        // linear runtime, no additional memory
-        float dot_y_dy = 0;
-        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
-        ggml_vec_cpy_f32  (nc, dx, dy);
-        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
-        ggml_vec_mul_f32  (nc, dx, dx, y);
-        ggml_vec_scale_f32(nc, dx, scale);
-
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dx[i]));
-            assert(!isinf(dx[i]));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_soft_max_ext_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_clamp
-
-static void ggml_compute_forward_clamp_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    float min;
-    float max;
-    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    for (int j = ith; j < n; j += nth) {
-        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
-        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
-
-        for (int i = 0; i < nc; i++) {
-            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
-        }
-    }
-}
-
-static void ggml_compute_forward_clamp(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_clamp_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_Q8_K:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_I64:
-        case GGML_TYPE_F64:
-        case GGML_TYPE_COUNT:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_rope
-
-static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
-    return 1 - MIN(1, MAX(0, y));
-}
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
-    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta) {
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
-    }
-    *cos_theta = cosf(theta) * mscale;
-    *sin_theta = sinf(theta) * mscale;
-}
-
-static void ggml_rope_cache_init(
-     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
-     float * cache, float sin_sign, float theta_scale) {
-    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
-    float theta = theta_base;
-    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
-        rope_yarn(
-            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
-        );
-        cache[i0 + 1] *= sin_sign;
-
-        theta *= theta_scale;
-    }
-}
-
-static void ggml_mrope_cache_init(
-     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
-     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
-     float * cache, float sin_sign, float theta_scale) {
-    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
-    float theta_t = theta_base_t;
-    float theta_h = theta_base_h;
-    float theta_w = theta_base_w;
-    float theta_e = theta_base_e;  // extra position id for vision encoder
-    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
-    int sec_w = sections[1] + sections[0];
-    int sec_e = sections[2] + sec_w;
-    GGML_ASSERT(sect_dims <= ne0);
-
-    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
-
-        int sector = (i0 / 2) % sect_dims;
-        if (indep_sects) {
-            // compute theta independently for each dim sections
-            // (i.e. reset corresponding theta when `i0` go from one section to another)
-            if (sector == 0) {
-                theta_t = theta_base_t;
-            }
-            else if (sector == sections[0]) {
-                theta_h = theta_base_h;;
-            }
-            else if (sector == sec_w) {
-                theta_w = theta_base_w;
-            }
-            else if (sector == sec_e) {
-                theta_e = theta_base_e;
-            }
-        }
-
-        float theta = theta_t;
-        if (sector >= sections[0] && sector < sec_w) {
-            theta = theta_h;
-        }
-        else if (sector >= sec_w && sector < sec_w + sections[2]) {
-            theta = theta_w;
-        }
-        else if (sector >= sec_w + sections[2]) {
-            theta = theta_e;
-        }
-
-        rope_yarn(
-            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
-        );
-        cache[i0 + 1] *= sin_sign;
-
-        theta_t *= theta_scale;
-        theta_w *= theta_scale;
-        theta_h *= theta_scale;
-        theta_e *= theta_scale;
-    }
-}
-
-static void ggml_compute_forward_rope_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const bool forward) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src2 = dst->src[2];
-
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[4];
-
-    //const int n_past     = ((int32_t *) dst->op_params)[0];
-    const int n_dims     = ((int32_t *) dst->op_params)[1];
-    const int mode       = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    GGML_ASSERT(n_dims <= ne0);
-    GGML_ASSERT(n_dims % 2 == 0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, multimodal rotary position embedding
-    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
-
-    if (is_mrope) {
-        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
-    }
-
-    if (is_vision) {
-        GGML_ASSERT(n_dims == ne0/2);
-    }
-
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        GGML_ASSERT(src2->type == GGML_TYPE_F32);
-        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-        freq_factors = (const float *) src2->data;
-    }
-
-    // backward process uses inverse rotation by cos and sin.
-    // cos and sin build a rotation matrix, where the inverse is the transpose.
-    // this essentially just switches the sign of sin.
-    const float sin_sign = forward ? 1.0f : -1.0f;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
-        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_mrope) {
-                const int64_t p = pos[i2];
-                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-            else {
-                const int64_t p_t = pos[i2];
-                const int64_t p_h = pos[i2 + ne2];
-                const int64_t p_w = pos[i2 + ne2 * 2];
-                const int64_t p_e = pos[i2 + ne2 * 3];
-                ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
-                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-
-            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
-                if (ir++ < ir0) continue;
-                if (ir   > ir1) break;
-
-                if (is_neox || is_mrope) {
-                    if (is_vision){
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims];
-
-                            dst_data[0]      = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims/2];
-
-                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-                        }
-                    }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float x0 = src[0];
-                        const float x1 = src[1];
-
-                        dst_data[0] = x0*cos_theta - x1*sin_theta;
-                        dst_data[1] = x0*sin_theta + x1*cos_theta;
-                    }
-                }
-
-                if (is_vision) {
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                        const float x0 = src[0];
-                        const float x1 = src[n_dims];
-
-                        dst_data[0]      = x0*cos_theta - x1*sin_theta;
-                        dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
-                    }
-                } else {
-                    // fill the remain channels with data from src tensor
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        dst_data[0] = src[0];
-                        dst_data[1] = src[1];
-                    }
-                }
-            }
-        }
-    }
-}
-
-// TODO: deduplicate f16/f32 code
-static void ggml_compute_forward_rope_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const bool forward) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src2 = dst->src[2];
-
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[4];
-
-    //const int n_past     = ((int32_t *) dst->op_params)[0];
-    const int n_dims     = ((int32_t *) dst->op_params)[1];
-    const int mode       = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
-
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    GGML_ASSERT(n_dims <= ne0);
-    GGML_ASSERT(n_dims % 2 == 0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
-    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
-
-    if (is_mrope) {
-        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
-    }
-
-    if (is_vision) {
-        GGML_ASSERT(n_dims == ne0/2);
-    }
-
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        GGML_ASSERT(src2->type == GGML_TYPE_F32);
-        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-        freq_factors = (const float *) src2->data;
-    }
-
-    // backward process uses inverse rotation by cos and sin.
-    // cos and sin build a rotation matrix, where the inverse is the transpose.
-    // this essentially just switches the sign of sin.
-    const float sin_sign = forward ? 1.0f : -1.0f;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_mrope) {
-                const int64_t p = pos[i2];
-                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-            else {
-                const int64_t p_t = pos[i2];
-                const int64_t p_h = pos[i2 + ne2];
-                const int64_t p_w = pos[i2 + ne2 * 2];
-                const int64_t p_e = pos[i2 + ne2 * 3];
-                ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
-                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                if (ir++ < ir0) continue;
-                if (ir   > ir1) break;
-
-                if (is_neox || is_mrope) {
-                    if (is_vision) {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                            const float x0 = GGML_FP16_TO_FP32(src[0]);
-                            const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
-
-                            dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                            const float x0 = GGML_FP16_TO_FP32(src[0]);
-                            const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
-
-                            dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        }
-                    }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[1]);
-
-                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                }
-
-                if (is_vision) {
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                        const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
-
-                        dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                } else {
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        dst_data[0] = src[0];
-                        dst_data[1] = src[1];
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_rope(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_rope_f16(params, dst, true);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rope_f32(params, dst, true);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_rope_back
-
-static void ggml_compute_forward_rope_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_rope_f16(params, dst, false);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rope_f32(params, dst, false);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
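
For readers skimming this removal: every RoPE kernel above reduces to the same per-pair rotation, and the backward kernel is the forward one with the sine sign flipped, since the inverse of a rotation matrix is its transpose. A minimal illustrative sketch (rope_pair is a name introduced here, not part of the patch):

/* Illustrative only: applies the cached rotation to one (x0, x1) pair;
 * the backward pass uses sin_sign = -1.0f. */
static inline void rope_pair(float cos_theta, float sin_theta, float sin_sign,
                             float x0, float x1, float * y0, float * y1) {
    const float s = sin_theta*sin_sign;
    *y0 = x0*cos_theta - x1*s;
    *y1 = x0*s + x1*cos_theta;
}
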
-
-// ggml_compute_forward_conv_transpose_1d
-
-static void ggml_compute_forward_conv_transpose_1d_f16_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (ith == 0) {
-        memset(params->wdata, 0, params->wsize);
-
-        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
-                    }
-                }
-            }
-        }
-
-        // permute source data (src1) from (L x Cin) to (Cin x L)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
-            ggml_fp16_t * dst_data = wdata;
-
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
-                }
-            }
-        }
-
-        // need to zero dst since we are accumulating into it
-        memset(dst->data, 0, ggml_nbytes(dst));
-    }
-    ggml_barrier(params->threadpool);
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
-    // total rows in dst
-    const int nr = ne1;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + nk;
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            const int i1n = i10*ne11;
-            for (int i00 = 0; i00 < ne00; i00++) {
-                float v = 0;
-                ggml_vec_dot_f16(ne02, &v, 0,
-                        (ggml_fp16_t *)    wdata_src + i1n, 0,
-                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
-                dst_data[i10*s0 + i00] += v;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_transpose_1d_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02;
-
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (ith == 0) {
-        memset(params->wdata, 0, params->wsize);
-
-        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
-        {
-            float * const wdata = (float *) params->wdata + 0;
-
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    float * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
-                    }
-                }
-            }
-        }
-
-        // prepare source data (src1)
-        {
-            float * const wdata = (float *) params->wdata + nk;
-            float * dst_data = wdata;
-
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = src[i10];
-                }
-            }
-        }
-
-        // need to zero dst since we are accumulating into it
-        memset(dst->data, 0, ggml_nbytes(dst));
-    }
-    ggml_barrier(params->threadpool);
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
-    // total rows in dst
-    const int nr = ne1;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * const wdata     = (float *) params->wdata + 0;
-    float * const wdata_src = wdata + nk;
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        float * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            const int i1n = i10*ne11;
-            for (int i00 = 0; i00 < ne00; i00++) {
-                float v = 0;
-                ggml_vec_dot_f32(ne02, &v, 0,
-                        wdata_src + i1n, 0,
-                        wdata_kernel + i00*ne02, 0, 1);
-                dst_data[i10*s0 + i00] += v;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_transpose_1d(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_transpose_1d_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
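
As a reference for the accumulation pattern in the two transposed-convolution kernels above, here is a single-channel scalar sketch under the same assumptions (conv_transpose_1d_ref is an illustrative name; the output has (L - 1)*s0 + K elements and must be zero-initialized, since each input position scatters into s0-strided output positions):

#include <stdint.h>

static void conv_transpose_1d_ref(const float * x, int64_t L,
                                  const float * w, int64_t K,
                                  int s0, float * y) {
    for (int64_t i = 0; i < L; ++i) {
        for (int64_t k = 0; k < K; ++k) {
            y[i*s0 + k] += x[i]*w[k]; // same scatter as dst_data[i10*s0 + i00] += v above
        }
    }
}
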
-
-// ggml_compute_forward_im2col_f32
-// src0: kernel [OC, IC, KH, KW]
-// src1: image [N, IC, IH, IW]
-// dst:  result [N, OH, OW, IC*KH*KW]
-static void ggml_compute_forward_im2col_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
-    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t N  = is_2D ? ne13 : ne12;
-    const int64_t IC = is_2D ? ne12 : ne11;
-    const int64_t IH = is_2D ? ne11 : 1;
-    const int64_t IW = ne10;
-
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
-
-    const int64_t OH = is_2D ? ne2 : 1;
-    const int64_t OW = ne1;
-
-    int ofs0 = is_2D ? nb13 : nb12;
-    int ofs1 = is_2D ? nb12 : nb11;
-
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
-    {
-        float * const wdata = (float *) dst->data;
-
-        for (int64_t in = 0; in < N; in++) {
-            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
-                for (int64_t iow = 0; iow < OW; iow++) {
-                    for (int64_t iic = ith; iic < IC; iic += nth) {
-
-                        // micro kernel
-                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
-
-                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
-                            for (int64_t ikw = 0; ikw < KW; ikw++) {
-                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
-                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
-
-                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
-                                } else {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-
-// ggml_compute_forward_im2col_f16
-// src0: kernel [OC, IC, KH, KW]
-// src1: image [N, IC, IH, IW]
-// dst:  result [N, OH, OW, IC*KH*KW]
-static void ggml_compute_forward_im2col_f16(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
-    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t N  = is_2D ? ne13 : ne12;
-    const int64_t IC = is_2D ? ne12 : ne11;
-    const int64_t IH = is_2D ? ne11 : 1;
-    const int64_t IW = ne10;
-
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
-
-    const int64_t OH = is_2D ? ne2 : 1;
-    const int64_t OW = ne1;
-
-    int ofs0 = is_2D ? nb13 : nb12;
-    int ofs1 = is_2D ? nb12 : nb11;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
-    {
-        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
-
-        for (int64_t in = 0; in < N; in++) {
-            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
-                for (int64_t iow = 0; iow < OW; iow++) {
-                    for (int64_t iic = ith; iic < IC; iic += nth) {
-
-                        // micro kernel
-                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
-
-                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
-                            for (int64_t ikw = 0; ikw < KW; ikw++) {
-                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
-                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
-
-                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
-                                } else {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_im2col(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-    switch (dst->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_im2col_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_im2col_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
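
The im2col kernels above only rearrange data; the convolution itself then becomes a matrix multiplication over the IC*KH*KW column. The OH/OW extents they read from dst follow the usual convolution arithmetic, sketched here for reference (conv_out_size is an illustrative helper, not taken from this patch):

#include <stdint.h>

/* Standard output-size formula for stride s, padding p, dilation d. */
static inline int64_t conv_out_size(int64_t in, int64_t k, int64_t s, int64_t p, int64_t d) {
    return (in + 2*p - d*(k - 1) - 1)/s + 1;
}
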
-
-// ggml_compute_forward_im2col_back_f32
-
-static void ggml_compute_forward_im2col_back_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
-    const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
-    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t N  = is_2D ? ne3 : ne2;
-    const int64_t IC = is_2D ? ne2 : ne1;
-    const int64_t IH = is_2D ? ne1 : 1;
-    const int64_t IW = ne0;
-
-    const int64_t KH = is_2D ? ne11 : 1;
-    const int64_t KW = ne10;
-
-    const int64_t OH = is_2D ? ne02 : 1;
-    const int64_t OW = ne01;
-
-    int ofs0 = is_2D ? nb3 : nb2;
-    int ofs1 = is_2D ? nb2 : nb1;
-
-    GGML_ASSERT(nb0  == sizeof(float));
-
-    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
-    {
-        float * const wdata = (float *) dst->data;
-
-        for (int64_t in = 0; in < N; in++) {
-            for (int64_t iic = ith; iic < IC; iic += nth) {
-                for (int64_t iih = 0; iih < IH; iih++) {
-                    for (int64_t iiw = 0; iiw < IW; iiw++) {
-
-                        // micro kernel
-                        float grad = 0.0f;
-                        for (int64_t ikh = 0; ikh < KH; ikh++) {
-                            for (int64_t ikw = 0; ikw < KW; ikw++) {
-                                // For s0 > 1 some values were skipped over in the forward pass.
-                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
-                                const int64_t tmpw = (iiw + p0 - ikw*d0);
-                                if (tmpw % s0 != 0) {
-                                    continue;
-                                }
-                                const int64_t iow = tmpw / s0;
-
-                                // Equivalent logic as above except for s1.
-                                int64_t ioh;
-                                if (is_2D) {
-                                    const int64_t tmph = iih + p1 - ikh*d1;
-
-                                    if (tmph % s1 != 0) {
-                                        continue;
-                                    }
-
-                                    ioh = tmph / s1;
-                                } else {
-                                    ioh = 0;
-                                }
-
-                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
-                                    continue;
-                                }
-
-                                const float * const grad_in = (const float *) src0->data
-                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
-                            }
-                        }
-                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
-                        dst_data[iih*IW + iiw] = grad;
-                    }
-                }
-            }
-        }
-    }
-}
-
-// ggml_compute_forward_conv_transpose_2d
-
-static void ggml_compute_forward_conv_transpose_2d(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02*ne03;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (ith == 0) {
-        memset(params->wdata, 0, params->wsize);
-
-        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
-                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
-                        }
-                    }
-                }
-            }
-        }
-
-        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
-            for (int i12 = 0; i12 < ne12; i12++) {
-                for (int i11 = 0; i11 < ne11; i11++) {
-                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
-                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
-                    for (int i10 = 0; i10 < ne10; i10++) {
-                        dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
-                    }
-                }
-            }
-        }
-
-        memset(dst->data, 0, ggml_nbytes(dst));
-    }
-    ggml_barrier(params->threadpool);
-
-    const int32_t stride = ggml_get_op_params_i32(dst, 0);
-
-    // total patches in dst
-    const int np = ne2;
-
-    // patches per thread
-    const int dp = (np + nth - 1)/nth;
-
-    // patch range for this thread
-    const int ip0 = dp*ith;
-    const int ip1 = MIN(ip0 + dp, np);
-
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + nk;
-
-    for (int i2 = ip0; i2 < ip1; i2++) { // Cout
-        float * dst_data = (float *)((char *) dst->data + i2*nb2);
-        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
-        for (int i11 = 0; i11 < ne11; i11++) {
-            for (int i10 = 0; i10 < ne10; i10++) {
-                const int i1n = i11*ne10*ne12 + i10*ne12;
-                for (int i01 = 0; i01 < ne01; i01++) {
-                    for (int i00 = 0; i00 < ne00; i00++) {
-                        float v = 0;
-                        ggml_vec_dot_f16(ne03, &v, 0,
-                                wdata_src + i1n, 0,
-                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
-                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
-                    }
-                }
-            }
-        }
-    }
-}
-
-// ggml_compute_forward_pool_1d_sk_p0
-
-static void ggml_compute_forward_pool_1d_sk_p0(
-        const struct ggml_compute_params * params,
-        const enum ggml_op_pool op,
-        const int k,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src = dst->src[0];
-
-    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    const char * cdata = (const char *)src->data;
-    const char * const data_end = cdata + ggml_nbytes(src);
-    float * drow = (float *)dst->data;
-
-    const int64_t rs = dst->ne[0];
-
-    while (cdata < data_end) {
-        const void * srow = (const void *)cdata;
-        int j = 0;
-        for (int64_t i = 0; i < rs; ++i) {
-            switch (op) {
-                case GGML_OP_POOL_AVG:   drow[i] = 0;        break;
-                case GGML_OP_POOL_MAX:   drow[i] = -FLT_MAX; break;
-                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
-            }
-            for (int ki = 0; ki < k; ++ki) {
-                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
-                switch (op) {
-                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
-                    case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break;
-                    case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error");
-                }
-                ++j;
-            }
-            switch (op) {
-                case GGML_OP_POOL_AVG:         drow[i] /= k; break;
-                case GGML_OP_POOL_MAX:                       break;
-                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
-            }
-        }
-
-        cdata += src->nb[1];
-        drow  += rs;
-    }
-}
-
-// ggml_compute_forward_pool_1d
-
-static void ggml_compute_forward_pool_1d(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const int32_t * opts = (const int32_t *)dst->op_params;
-    enum ggml_op_pool op = opts[0];
-    const int k0 = opts[1];
-    const int s0 = opts[2];
-    const int p0 = opts[3];
-    GGML_ASSERT(p0 == 0); // padding not supported
-    GGML_ASSERT(k0 == s0); // only s = k supported
-
-    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
-}
-
-// ggml_compute_forward_pool_2d
-
-static void ggml_compute_forward_pool_2d(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src = dst->src[0];
-
-    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    const int32_t * opts = (const int32_t *)dst->op_params;
-    enum ggml_op_pool op = opts[0];
-    const int k0 = opts[1];
-    const int k1 = opts[2];
-    const int s0 = opts[3];
-    const int s1 = opts[4];
-    const int p0 = opts[5];
-    const int p1 = opts[6];
-    const char * cdata = (const char*)src->data;
-    const char * const data_end = cdata + ggml_nbytes(src);
-
-    const int64_t px = dst->ne[0];
-    const int64_t py = dst->ne[1];
-    const int64_t pa = px * py;
-
-    float * dplane = (float *)dst->data;
-
-    const int ka = k0 * k1;
-    const int offset0 = -p0;
-    const int offset1 = -p1;
-
-    while (cdata < data_end) {
-        for (int oy = 0; oy < py; ++oy) {
-            float * const drow = dplane + oy * px;
-            for (int ox = 0; ox < px; ++ox) {
-                float * const out =  drow + ox;
-                switch (op) {
-                    case GGML_OP_POOL_AVG:     *out = 0;        break;
-                    case GGML_OP_POOL_MAX:     *out = -FLT_MAX; break;
-                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
-                }
-
-                const int ix = offset0 + ox * s0;
-                const int iy = offset1 + oy * s1;
-
-                for (int ky = 0; ky < k1; ++ky) {
-                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
-                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
-                    for (int kx = 0; kx < k0; ++kx) {
-                        int j = ix + kx;
-                        if (j < 0 || j >= src->ne[0]) continue;
-                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
-                        switch (op) {
-                            case GGML_OP_POOL_AVG:                     *out += srow_j; break;
-                            case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break;
-                            case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
-                        }
-                    }
-                }
-                switch (op) {
-                    case GGML_OP_POOL_AVG:           *out /= ka; break;
-                    case GGML_OP_POOL_MAX:                       break;
-                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
-                }
-            }
-        }
-
-        cdata  += src->nb[2];
-        dplane += pa;
-    }
-}
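
One behavioural detail of the 2-D pooling kernel above: for GGML_OP_POOL_AVG the sum is always divided by the full window area k0*k1, even where the window hangs over the padding, so border outputs are attenuated rather than renormalized. A scalar sketch of a single window (pool2d_avg_window is a hypothetical name used only for illustration):

static float pool2d_avg_window(const float * plane, int W, int H,
                               int ix, int iy, int k0, int k1) {
    float sum = 0.0f;
    for (int ky = 0; ky < k1; ++ky) {
        for (int kx = 0; kx < k0; ++kx) {
            const int x = ix + kx;
            const int y = iy + ky;
            if (x >= 0 && x < W && y >= 0 && y < H) {
                sum += plane[y*W + x]; // out-of-bounds taps contribute 0
            }
        }
    }
    return sum / (k0*k1); // divisor is the full window area, as in the kernel above
}
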
-
-// ggml_compute_forward_pool_2d_back
-
-static void ggml_compute_forward_pool_2d_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src  = dst->src[0];
-    const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
-
-    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    const int32_t * opts = (const int32_t *)dst->op_params;
-    enum ggml_op_pool op = opts[0];
-    const int k0 = opts[1];
-    const int k1 = opts[2];
-    const int s0 = opts[3];
-    const int s1 = opts[4];
-    const int p0 = opts[5];
-    const int p1 = opts[6];
-
-    char       * cdata  = (char       *) dst->data;
-    const char * cdataf = (const char *) dstf->data;
-    const char * const data_end = cdata + ggml_nbytes(dst);
-
-    GGML_ASSERT(params->ith == 0);
-    memset(cdata, 0, ggml_nbytes(dst));
-
-    const int64_t px = src->ne[0];
-    const int64_t py = src->ne[1];
-    const int64_t pa = px * py;
-
-    const float * splane = (const float *) src->data;
-
-    const int ka = k0 * k1;
-    const int offset0 = -p0;
-    const int offset1 = -p1;
-
-    while (cdata < data_end) {
-        for (int oy = 0; oy < py; ++oy) {
-            const float * const srow = splane + oy * px;
-            for (int ox = 0; ox < px; ++ox) {
-                const float grad0 = srow[ox];
-
-                const int ix = offset0 + ox * s0;
-                const int iy = offset1 + oy * s1;
-
-                if (op == GGML_OP_POOL_MAX) {
-                    float maxval = -FLT_MAX;
-                    int kxmax = -1;
-                    int kymax = -1;
-
-                    for (int ky = 0; ky < k1; ++ky) {
-                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
-                            continue;
-                        }
-                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
-                        for (int kx = 0; kx < k0; ++kx) {
-                            int j = ix + kx;
-                            if (j < 0 || j >= dst->ne[0]) {
-                                continue;
-                            }
-
-                            const float val = dst->type == GGML_TYPE_F32 ?
-                                ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
-                            if (val <= maxval) {
-                                continue;
-                            }
-
-                            maxval = val;
-                            kxmax = kx;
-                            kymax = ky;
-                        }
-                    }
-
-                    if (kxmax == -1 || kymax == -1) {
-                        continue;
-                    }
-
-                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
-                    const int j = ix + kxmax;
-                    if (dst->type == GGML_TYPE_F32) {
-                        ((float *) drow)[j] += grad0;
-                    } else {
-                        ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
-                    }
-                } else if (op == GGML_OP_POOL_AVG) {
-                    const float grad = grad0 / ka;
-
-                    for (int ky = 0; ky < k1; ++ky) {
-                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
-                            continue;
-                        }
-                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
-                        for (int kx = 0; kx < k0; ++kx) {
-                            int j = ix + kx;
-                            if (j < 0 || j >= dst->ne[0]) {
-                                continue;
-                            }
-
-                            if (dst->type == GGML_TYPE_F32) {
-                                ((float *) drow)[j] += grad;
-                            } else {
-                                ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad);
-                            }
-                        }
-                    }
-                } else {
-                    GGML_ASSERT(false);
-                }
-            }
-        }
-
-        cdata  += dst->nb[2];
-        cdataf += dst->nb[2];
-        splane += pa;
-    }
-}
-
-// ggml_compute_forward_upscale
-
-static void ggml_compute_forward_upscale_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const float sf0 = (float)ne0/src0->ne[0];
-    const float sf1 = (float)ne1/src0->ne[1];
-    const float sf2 = (float)ne2/src0->ne[2];
-    const float sf3 = (float)ne3/src0->ne[3];
-
-    // TODO: optimize
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        const int64_t i03 = i3 / sf3;
-        for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
-            const int64_t i02 = i2 / sf2;
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                const int64_t i01 = i1 / sf1;
-                for (int64_t i0 = 0; i0 < ne0; i0++) {
-                    const int64_t i00 = i0 / sf0;
-
-                    const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                          float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
-
-                    *y = *x;
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_upscale(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_upscale_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
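
The upscale kernel above is plain nearest-neighbour: each destination index maps back to a source index by truncating division with the per-dimension scale factor. Sketch (upscale_src_index is an illustrative name):

#include <stdint.h>

static inline int64_t upscale_src_index(int64_t i_dst, int64_t ne_dst, int64_t ne_src) {
    const float sf = (float) ne_dst/ne_src; // same sf0..sf3 as in the kernel above
    return (int64_t) (i_dst/sf);            // truncation picks the lower neighbour
}
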
-
-
-// ggml_compute_forward_pad
-
-static void ggml_compute_forward_pad_f32(
-    const struct ggml_compute_params * params,
-          struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    float * dst_ptr = (float *) dst->data;
-
-    // TODO: optimize
-
-    for (int64_t i2 = 0; i2 < ne2; ++i2) {
-        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                for (int64_t i3 = 0; i3 < ne3; ++i3) {
-                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
-                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
-                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
-                        dst_ptr[dst_idx] = *src_ptr;
-                    } else {
-                        dst_ptr[dst_idx] = 0;
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_pad(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_pad_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_pad_reflect_1d
-
-static void ggml_compute_forward_pad_reflect_1d(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int32_t * opts = (const int32_t *) dst->op_params;
-    const int p0 = opts[0];
-    const int p1 = opts[1];
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
-                float * left  = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 +         p0*nb0);
-                float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
-
-                ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
-
-                for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   }
-                for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
-            }
-        }
-    }
-}
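
To make the reflect-padding semantics above concrete (worked example with values chosen for illustration): with src = {a, b, c, d}, p0 = 2 and p1 = 2, the row is copied starting at offset p0 and then mirrored around its first and last elements:

// src            : a b c d
// after copy     : _ _ a b c d _ _
// left mirror    : c b a b c d _ _    (left[-i0] = left[i0])
// right mirror   : c b a b c d c b    (right[i0] = right[-i0])
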
-
-static void ggml_compute_forward_unpad_f32(
-    const struct ggml_compute_params *params,
-    struct ggml_tensor *dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    float * dst_ptr = (float *) dst->data;
-
-    // TODO: optimize
-
-    for (int64_t i2 = 0; i2 < ne2; ++i2) {
-        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                for (int64_t i3 = 0; i3 < ne3; ++i3) {
-                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
-                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
-                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
-                        dst_ptr[dst_idx] = *src_ptr;
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_unpad(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_unpad_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_arange
-
-static void ggml_compute_forward_arange_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    GGML_ASSERT(dst->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const float start = ggml_get_op_params_f32(dst, 0);
-    const float stop  = ggml_get_op_params_f32(dst, 1);
-    const float step  = ggml_get_op_params_f32(dst, 2);
-
-    const int64_t steps = (int64_t) ceilf((stop - start) / step);
-
-    GGML_ASSERT(ggml_nelements(dst) == steps);
-
-    for (int64_t i = ith; i < steps; i+= nth) {
-        float value = start + step * i;
-        ((float *)dst->data)[i] = value;
-    }
-}
-
-static void ggml_compute_forward_arange(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-    switch (dst->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_arange_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
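
A quick worked example of the arange kernel above (values assumed for illustration): with start = 0, stop = 5, step = 2 the row is split across threads by i += nth and holds:

// steps = (int64_t) ceilf((5.0f - 0.0f)/2.0f) = 3
// dst   = { start + 0*step, start + 1*step, start + 2*step } = { 0, 2, 4 }
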
-
-static void ggml_compute_forward_timestep_embedding_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int dim = ggml_get_op_params_i32(dst, 0);
-    const int max_period = ggml_get_op_params_i32(dst, 1);
-
-    int half = dim / 2;
-
-    for (int64_t i = 0; i < ne00; i++) {
-        float * embed_data = (float *)((char *)  dst->data +  i*nb1);
-        for (int64_t j = ith; j < half; j += nth) {
-            float timestep = ((float *)src0->data)[i];
-            float freq = (float)expf(-logf(max_period) * j / half);
-            float arg = timestep * freq;
-            embed_data[j] = cosf(arg);
-            embed_data[j + half] = sinf(arg);
-        }
-        if (dim % 2 != 0 && ith == 0) {
-            embed_data[dim] = 0.f;
-        }
-    }
-}
-
-static void ggml_compute_forward_timestep_embedding(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_timestep_embedding_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
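
Restated as a standalone sketch, the timestep-embedding kernel above fills each row with cos/sin features at geometrically spaced frequencies (timestep_embed_ref is an illustrative name; the buffer needs dim + 1 slots when dim is odd, matching the zeroed trailing element above):

#include <math.h>

static void timestep_embed_ref(float t, int dim, int max_period, float * embed) {
    const int half = dim/2;
    for (int j = 0; j < half; ++j) {
        const float freq = expf(-logf((float) max_period)*j/half);
        embed[j]        = cosf(t*freq);
        embed[j + half] = sinf(t*freq);
    }
    if (dim % 2 != 0) {
        embed[dim] = 0.f; // odd dim: extra trailing slot is zeroed, as above
    }
}
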
-
-// ggml_compute_forward_argsort
-
-static void ggml_compute_forward_argsort_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(nb0 == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t nr = ggml_nrows(src0);
-
-    enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
-
-    for (int64_t i = ith; i < nr; i += nth) {
-        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
-        const float * src_data = (float *)((char *) src0->data + i*nb01);
-
-        for (int64_t j = 0; j < ne0; j++) {
-            dst_data[j] = j;
-        }
-
-        // C doesn't have a functional sort, so we do a bubble sort instead
-        for (int64_t j = 0; j < ne0; j++) {
-            for (int64_t k = j + 1; k < ne0; k++) {
-                if ((order == GGML_SORT_ORDER_ASC  && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
-                    (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
-                    int32_t tmp = dst_data[j];
-                    dst_data[j] = dst_data[k];
-                    dst_data[k] = tmp;
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_argsort(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_argsort_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
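
For clarity on the argsort output above (example values chosen for illustration): per row, the kernel writes the permutation of indices that orders the source values, using an O(n^2) in-place swap sort.

// src (one row)        : {  0.3f, -1.0f,  2.5f,  0.0f }
// GGML_SORT_ORDER_ASC  : dst = { 1, 3, 0, 2 }   // src[1] <= src[3] <= src[0] <= src[2]
// GGML_SORT_ORDER_DESC : dst = { 2, 0, 3, 1 }
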
-
-// ggml_compute_forward_flash_attn_ext
-
-static void ggml_compute_forward_flash_attn_ext_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
-        struct ggml_tensor * dst) {
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne2 == N);
-
-    // input tensor rows must be contiguous
-    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
-    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
-    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    // broadcast factors
-    const int64_t rk2 = neq2/nek2;
-    const int64_t rk3 = neq3/nek3;
-
-    const int64_t rv2 = neq2/nev2;
-    const int64_t rv3 = neq3/nev3;
-
-    // parallelize by q rows using the type-specific kq_vec_dot
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float scale         = 1.0f;
-    float max_bias      = 0.0f;
-    float logit_softcap = 0.0f;
-
-    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
-
-    if (logit_softcap != 0) {
-        scale /= logit_softcap;
-    }
-
-    const uint32_t n_head      = neq2;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    enum ggml_type    const k_vec_dot_type = type_traits_cpu[k->type].vec_dot_type;
-    ggml_from_float_t const q_to_vec_dot   = type_traits_cpu[k_vec_dot_type].from_float;
-    ggml_vec_dot_t    const kq_vec_dot     = type_traits_cpu[k->type].vec_dot;
-    ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
-
-    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
-    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
-
-    // loop over n_batch and n_head
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        const uint32_t h = iq2; // head index
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float S = 0.0f;      // sum
-        float M = -INFINITY; // maximum KQ value
-
-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
-
-        if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
-        } else {
-            memset(VKQ32, 0, D*sizeof(float));
-        }
-
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
-
-        // k indices
-        const int ik3 = iq3 / rk3;
-        const int ik2 = iq2 / rk2;
-
-        // v indices
-        const int iv3 = iq3 / rv3;
-        const int iv2 = iq2 / rv2;
-
-        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
-
-        // online softmax / attention
-        // loop over n_kv and n_head_kv
-        // ref: https://arxiv.org/pdf/2112.05682.pdf
-        for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
-            if (mv == -INFINITY) {
-                continue;
-            }
-
-            float s; // KQ value
-
-            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
-
-            s = s*scale; // scale KQ value
-
-            if (logit_softcap != 0.0f) {
-                s = logit_softcap*tanhf(s);
-            }
-
-            s += mv; // apply mask
-
-            const float Mold = M;
-
-            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
-            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
-
-            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-
-            if (v->type == GGML_TYPE_F16) {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
-            } else {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                v_to_float(v_data, V32, D);
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
-            }
-
-            S = S*ms + vs; // scale and increment sum with partial sum
-        }
-
-        if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
-                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
-            }
-        }
-
-        // V /= S
-        const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
-
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
-
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
-
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
-    }
-}
-
-static void ggml_compute_forward_flash_attn_ext(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
-        struct ggml_tensor * dst) {
-    switch (dst->op_params[3]) {
-        case GGML_PREC_DEFAULT:
-        case GGML_PREC_F32:
-            {
-                // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
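
The flash-attention kernel above relies on the online-softmax recurrence (reference cited in the code: arxiv 2112.05682): a running maximum M, denominator S and weighted accumulator V are rescaled whenever a larger score appears, so K and V are visited in a single pass. A scalar sketch, with online_softmax_step as an illustrative name:

#include <math.h>

/* One K/V column with score s and (scalar) value v; the caller initializes
 * M = -INFINITY, S = 0, V = 0 and finally outputs V/S. */
static void online_softmax_step(float s, float v, float * M, float * S, float * V) {
    float ms = 1.0f;        // rescale factor for the old accumulators
    float vs = 1.0f;        // weight of the new value, expf(s - M)
    if (s > *M) {
        ms = expf(*M - s);  // shrink what was accumulated under the old maximum
        *M = s;
    } else {
        vs = expf(s - *M);
    }
    *V = *V*ms + v*vs;
    *S = *S*ms + vs;
}
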
-
-// ggml_compute_forward_flash_attn_back
-
-static void ggml_compute_forward_flash_attn_back_f32(
-        const struct ggml_compute_params * params,
-        const bool masked,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-    const struct ggml_tensor * d = dst->src[3];
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-    const int mxDM = MAX(D, Mup);
-
-    // GGML_ASSERT(ne0 == D);
-    // GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(float));
-    GGML_ASSERT(nbv0 == sizeof(float));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-    GGML_ASSERT(ned0 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-    GGML_ASSERT(ned1 == N);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (ith == 0) {
-        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
-    }
-    ggml_barrier(params->threadpool);
-
-    const int64_t elem_q = ggml_nelements(q);
-    const int64_t elem_k = ggml_nelements(k);
-
-    enum ggml_type result_type = dst->type;
-    GGML_ASSERT(ggml_blck_size(result_type) == 1);
-    const size_t tsize = ggml_type_size(result_type);
-
-    const size_t offs_q = 0;
-    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
-    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
-
-    void * grad_q = (char *) dst->data;
-    void * grad_k = (char *) dst->data + offs_k;
-    void * grad_v = (char *) dst->data + offs_v;
-
-    const size_t nbgq1 = nb0*neq0;
-    const size_t nbgq2 = nb0*neq0*neq1;
-    const size_t nbgq3 = nb0*neq0*neq1*neq2;
-
-    const size_t nbgk1 = nb0*nek0;
-    const size_t nbgk2 = nb0*nek0*nek1;
-    const size_t nbgk3 = nb0*nek0*nek1*neq2;
-
-    const size_t nbgv1 = nb0*nev0;
-    const size_t nbgv2 = nb0*nev0*nev1;
-    const size_t nbgv3 = nb0*nev0*nev1*neq2;
-
-    // parallelize by k rows using ggml_vec_dot_f32
-
-    // total rows in k
-    const int nr = nek2*nek3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    // how often k2 (and v2) is repeated in q2
-    int nrep = neq2/nek2;
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // k indices
-        const int ik3 = ir/(nek2);
-        const int ik2 = ir - ik3*nek2;
-
-        const int iq3 = ik3;
-        const int id3 = ik3;
-        const int iv3 = ik3;
-        const int iv2 = ik2;
-
-        for (int irep = 0; irep < nrep; ++irep) {
-            const int iq2 = ik2 + irep*nek2;
-            const int id2 = iq2;
-
-            // (ik2 + irep*nek2) % nek2 == ik2
-            for (int iq1 = 0; iq1 < neq1; ++iq1) {
-                const int id1 = iq1;
-
-                // not sure about CACHE_LINE_SIZE_F32..
-                // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
-                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
-                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
-
-                for (int i = M; i < Mup; ++i) {
-                    S[i] = -INFINITY;
-                }
-
-                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
-                for (int64_t ic = 0; ic < masked_begin; ++ic) {
-                    // k indices
-                    const int ik1 = ic;
-
-                    // S indices
-                    const int i1 = ik1;
-
-                    ggml_vec_dot_f32(neq0,
-                            S + i1, 0,
-                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-                }
-
-                // scale
-                ggml_vec_scale_f32(masked_begin, S, scale);
-
-                for (int64_t i = masked_begin; i < M; i++) {
-                    S[i] = -INFINITY;
-                }
-
-                // softmax
-                // exclude known -INF S[..] values from max and loop
-                // dont forget to set their SM values to zero
-                {
-                    float max = -INFINITY;
-                    ggml_vec_max_f32(masked_begin, &max, S);
-
-                    ggml_float sum = 0.0;
-                    {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                        max = -max;
-                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
-                        vvexpf(SM, SM, &Mup);
-                        ggml_vec_sum_f32(Mup, &sum, SM);
-#else
-                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
-#endif
-                    }
-
-                    assert(sum > 0.0);
-
-                    sum = 1.0/sum;
-                    ggml_vec_scale_f32(masked_begin, SM, sum);
-
-                }
-
-                // step-by-step explanation
-                {
-                    // forward-process                    shape      grads from backward process
-                    // parallel_for ik2,ik3:
-                    //  for irep:
-                    //   iq2 = ik2 + irep*nek2
-                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]
-                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
-                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]
-                    //   for iq1:
-                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
-                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
-                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
-                    //    S0     = -Inf                   [D,1,1,1]
-                    //   ~S1[i]  = dot(kcur[:D,i], qcur)
-                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
-                    //    S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
-                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
-                    //   ~S5[i]  = dot(vcur[:,i], S4)
-                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]
-                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
-                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
-                    // dst                               backward-/ grad[dst]                 = d
-                    //
-                    // output gradients with their dependencies:
-                    //
-                    // grad[kcur] = grad[S1].T @ qcur
-                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
-                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-                    // grad[S4]   = grad[S5] @ vcur
-                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
-                    // grad[qcur] = grad[S1]   @ kcur
-                    // grad[vcur] = grad[S5].T @ S4
-                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
-                    //
-                    // in post-order:
-                    //
-                    // S1         = qcur @ kcur.T
-                    // S2         = S1 * scale
-                    // S3         = diag_mask_inf(S2, P)
-                    // S4         = softmax(S3)
-                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
-                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
-                    // grad[qcur] = grad[S1]   @ kcur
-                    // grad[kcur] = grad[S1].T @ qcur
-                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
-                    //
-                    // using less variables (SM=S4):
-                    //
-                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
-                    // SM            = softmax(S)
-                    // S             = d[:D,iq1,iq2,iq3] @ vcur
-                    // dot_SM_gradSM = dot(SM, S)
-                    // S             = SM * (S - dot(SM, S))
-                    // S             = diag_mask_zero(S, P) * scale
-                    //
-                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
-                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur
-                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM
-                }
-
-                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
-                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
-                // for ic:
-                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
-                // exclude known future zero S[..] values from operation
-                ggml_vec_set_f32(masked_begin, S, 0);
-                for (int64_t ic = 0; ic < D; ++ic) {
-                    ggml_vec_mad_f32(masked_begin,
-                            S,
-                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),
-                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
-                }
-
-                // S = SM * (S - dot(SM, S))
-                float dot_SM_gradSM = 0;
-                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
-                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
-                ggml_vec_mul_f32 (masked_begin, S, S, SM);
-
-                // S = diag_mask_zero(S, P) * scale
-                // already done by above ggml_vec_set_f32
-
-                // exclude known zero S[..] values from operation
-                ggml_vec_scale_f32(masked_begin, S, scale);
-
-                // S    shape [M,1]
-                // SM   shape [M,1]
-                // kcur shape [D,M]
-                // qcur shape [D,1]
-                // vcur shape [M,D]
-
-                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
-                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
-                // for ic:
-                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
-                // exclude known zero S[..] values from loop
-                for (int64_t ic = 0; ic < masked_begin; ++ic) {
-                    ggml_vec_mad_f32(D,
-                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),
-                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),
-                            S[ic]);
-                }
-
-                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
-                // for ic:
-                //  grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
-                //  grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
-                // exclude known zero S[..] values from loop
-                for (int64_t ic = 0; ic < masked_begin; ++ic) {
-                    ggml_vec_mad_f32(D,
-                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),
-                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),
-                            S[ic]);
-                }
-
-                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM
-                // for ic:
-                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
-                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]
-                // exclude known zero SM[..] values from mad
-                for (int64_t ic = 0; ic < D; ++ic) {
-                    ggml_vec_mad_f32(masked_begin,
-                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
-                            SM,
-                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn_back(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-
-    switch (q->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
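// Reference sketch (standalone, illustrative; not part of this patch) of the
// softmax-backward identity used by the removed flash_attn_back kernel above:
// with SM = softmax(S) and upstream gradient gSM, grad[S] = SM * (gSM - dot(SM, gSM)).
// The function name and signature are assumptions for illustration only.
static void softmax_backward_sketch(int n, float * gS, const float * sm, const float * gSM) {
    float dot = 0.0f;
    for (int i = 0; i < n; ++i) {
        dot += sm[i] * gSM[i];            // dot(SM, grad[SM])
    }
    for (int i = 0; i < n; ++i) {
        gS[i] = sm[i] * (gSM[i] - dot);   // grad[S] = SM * (grad[SM] - dot)
    }
}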
-// ggml_compute_forward_ssm_conv
-
-static void ggml_compute_forward_ssm_conv_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
-    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc  = src1->ne[0]; // d_conv
-    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
-    const int nr  = src0->ne[1]; // d_inner
-    const int n_t =  dst->ne[1]; // tokens per sequence
-    const int n_s =  dst->ne[2]; // number of sequences in the batch
-
-    GGML_ASSERT( dst->ne[0] == nr);
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-    const int ir  = ir1 - ir0;
-
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            // {d_conv - 1 + n_t, d_inner, n_seqs}
-            // sliding window
-            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
-            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
-            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
-
-            // TODO: transpose the output for smaller strides for big batches?
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // rowwise dot product
-                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
-                float sumf = 0.0f;
-
-                // d_conv
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
-                }
-                x[i1] = sumf;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_ssm_conv(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    switch (dst->src[0]->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_ssm_conv_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
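// Reference sketch (standalone, illustrative; not part of this patch) of the
// per-row sliding-window dot product computed by the removed ssm_conv kernel:
// each output element is dot(window of conv_x, conv1d weight) over d_conv taps,
// accumulated in plain float as in the kernel. Names are illustrative only.
static float ssm_conv_row_sketch(int d_conv, const float * window, const float * weight) {
    float sum = 0.0f;
    for (int i = 0; i < d_conv; ++i) {
        sum += window[i] * weight[i];
    }
    return sum;
}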
-// ggml_compute_forward_ssm_scan
-
-static void ggml_compute_forward_ssm_scan_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // s
-    const struct ggml_tensor * src1 = dst->src[1]; // x
-    const struct ggml_tensor * src2 = dst->src[2]; // dt
-    const struct ggml_tensor * src3 = dst->src[3]; // A
-    const struct ggml_tensor * src4 = dst->src[4]; // B
-    const struct ggml_tensor * src5 = dst->src[5]; // C
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t nc  = src0->ne[0]; // d_state
-    const int64_t nr  = src0->ne[1]; // d_inner
-    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
-    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
-
-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src2->nb[0] == sizeof(float));
-    GGML_ASSERT(src3->nb[0] == sizeof(float));
-    GGML_ASSERT(src4->nb[0] == sizeof(float));
-    GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-    const int ir  = ir1 - ir0;
-
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                  float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                  float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
-                }
-                y[i1] = sumf;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_ssm_scan(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    switch (dst->src[0]->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_ssm_scan_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
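// Reference sketch (standalone, illustrative; not part of this patch) of the
// selective-scan recurrence in the removed ssm_scan kernel, for one inner channel:
// dt' = softplus(dt); state = state*exp(dt'*A) + B*(x*dt'); y = dot(state, C).
// Names and the signature are illustrative only.
#include <math.h>

static float ssm_scan_channel_sketch(int d_state, float * state, float x, float dt,
                                     const float * A, const float * B, const float * C) {
    const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt; // softplus, hard-switched as above
    const float x_dt  = x * dt_sp;
    float y = 0.0f;
    for (int i = 0; i < d_state; ++i) {
        state[i] = state[i] * expf(dt_sp * A[i]) + B[i] * x_dt; // state = state*dA + dB*x
        y += state[i] * C[i];                                   // y = rowwise dot(state, C)
    }
    return y;
}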
-// ggml_compute_forward_win_part
-
-static void ggml_compute_forward_win_part_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    UNUSED(params);
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
-
-    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
-    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
-    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
-
-    assert(ne00 == ne0);
-    assert(ne3  == nep0*nep1);
-
-    // TODO: optimize / multi-thread
-    for (int py = 0; py < nep1; ++py) {
-        for (int px = 0; px < nep0; ++px) {
-            const int64_t i3 = py*nep0 + px;
-            for (int64_t i2 = 0; i2 < ne2; ++i2) {
-                for (int64_t i1 = 0; i1 < ne1; ++i1) {
-                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                        const int64_t i02 = py*w + i2;
-                        const int64_t i01 = px*w + i1;
-                        const int64_t i00 = i0;
-
-                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
-                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
-
-                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
-                            ((float *) dst->data)[i] = 0.0f;
-                        } else {
-                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_win_part(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_win_part_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_win_unpart
-
-static void ggml_compute_forward_win_unpart_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    UNUSED(params);
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
-
-    const int32_t w = ((const int32_t *)(dst->op_params))[0];
-
-    // padding
-    const int px = (w - ne1%w)%w;
-    //const int py = (w - ne2%w)%w;
-
-    const int npx = (px + ne1)/w;
-    //const int npy = (py + ne2)/w;
-
-    assert(ne0 == ne00);
-
-    // TODO: optimize / multi-thread
-    for (int64_t i2 = 0; i2 < ne2; ++i2) {
-        for (int64_t i1 = 0; i1 < ne1; ++i1) {
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                const int ip2 = i2/w;
-                const int ip1 = i1/w;
-
-                const int64_t i02 = i2%w;
-                const int64_t i01 = i1%w;
-                const int64_t i00 = i0;
-
-                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
-                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;
-
-                ((float *) dst->data)[j] = ((float *) src0->data)[i];
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_win_unpart(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_win_unpart_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-//gmml_compute_forward_unary
-
-static void ggml_compute_forward_unary(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const enum ggml_unary_op op = ggml_get_unary_op(dst);
-
-    switch (op) {
-        case GGML_UNARY_OP_ABS:
-            {
-                ggml_compute_forward_abs(params, dst);
-            } break;
-        case GGML_UNARY_OP_SGN:
-            {
-                ggml_compute_forward_sgn(params, dst);
-            } break;
-        case GGML_UNARY_OP_NEG:
-            {
-                ggml_compute_forward_neg(params, dst);
-            } break;
-        case GGML_UNARY_OP_STEP:
-            {
-                ggml_compute_forward_step(params, dst);
-            } break;
-        case GGML_UNARY_OP_TANH:
-            {
-                ggml_compute_forward_tanh(params, dst);
-            } break;
-        case GGML_UNARY_OP_ELU:
-            {
-                ggml_compute_forward_elu(params, dst);
-            } break;
-        case GGML_UNARY_OP_RELU:
-            {
-                ggml_compute_forward_relu(params, dst);
-            } break;
-        case GGML_UNARY_OP_SIGMOID:
-            {
-                ggml_compute_forward_sigmoid(params, dst);
-            } break;
-        case GGML_UNARY_OP_GELU:
-            {
-                ggml_compute_forward_gelu(params, dst);
-            } break;
-        case GGML_UNARY_OP_GELU_QUICK:
-            {
-                ggml_compute_forward_gelu_quick(params, dst);
-            } break;
-        case GGML_UNARY_OP_SILU:
-            {
-                ggml_compute_forward_silu(params, dst);
-            } break;
-        case GGML_UNARY_OP_HARDSWISH:
-            {
-                ggml_compute_forward_hardswish(params, dst);
-            } break;
-        case GGML_UNARY_OP_HARDSIGMOID:
-            {
-                ggml_compute_forward_hardsigmoid(params, dst);
-            } break;
-        case GGML_UNARY_OP_EXP:
-            {
-                ggml_compute_forward_exp(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_get_rel_pos
-
-static void ggml_compute_forward_get_rel_pos_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    UNUSED(params);
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int64_t w = ne1;
-
-    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
-    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
-
-    for (int64_t i2 = 0; i2 < ne2; ++i2) {
-        for (int64_t i1 = 0; i1 < ne1; ++i1) {
-            const int64_t pos = (w - i1 - 1) + i2;
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_get_rel_pos(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-            {
-                ggml_compute_forward_get_rel_pos_f16(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_add_rel_pos
-
-static void ggml_compute_forward_add_rel_pos_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src2 = dst->src[2];
-
-    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
-    if (!inplace) {
-        if (params->ith == 0) {
-            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
-        }
-        ggml_barrier(params->threadpool);
-    }
-    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
-
-    float * src1_data = (float *) src1->data;
-    float * src2_data = (float *) src2->data;
-    float * dst_data  = (float *) dst->data;
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // total patches in dst
-    const int np = ne13;
-
-    // patches per thread
-    const int dp = (np + nth - 1)/nth;
-
-    // patch range for this thread
-    const int ip0 = dp*ith;
-    const int ip1 = MIN(ip0 + dp, np);
-
-    for (int64_t i13 = ip0; i13 < ip1; ++i13) {
-        for (int64_t i12 = 0; i12 < ne12; ++i12) {
-            for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
-                for (int64_t i10 = 0; i10 < ne10; ++i10) {
-                    const int64_t jp0  = jp1 + i10;
-                    const float src1_e = src1_data[jp0];
-                    const float src2_e = src2_data[jp0];
-
-                    const int64_t jdh = jp0 * ne10;
-                    const int64_t jdw = jdh - (ne10 - 1) * i10;
-
-                    for (int64_t j = 0; j < ne10; ++j) {
-                        dst_data[jdh + j     ] += src2_e;
-                        dst_data[jdw + j*ne10] += src1_e;
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_add_rel_pos(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_add_rel_pos_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_rwkv_wkv6
-
-static void ggml_compute_forward_rwkv_wkv6_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[2];
-    const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[1];
-    const int64_t n_seqs = dst->src[5]->ne[1];
-    const int64_t head_size = C / HEADS;
-
-    float * dst_data = (float *) dst->data;
-    float * state = ((float *) dst->data) + C * T;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    if (ith >= HEADS) {
-        return;
-    }
-
-    const int h_start = (HEADS * ith) / nth;
-    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
-                (HEADS * (ith + 1)) / nth : HEADS;
-
-    float * k =          (float *) dst->src[0]->data;
-    float * v =          (float *) dst->src[1]->data;
-    float * r =          (float *) dst->src[2]->data;
-    float * time_faaaa = (float *) dst->src[3]->data;
-    float * time_decay = (float *) dst->src[4]->data;
-
-    size_t t_stride = HEADS * head_size; // Same to C
-
-    size_t h_stride = C / HEADS;
-    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
-    size_t h_stride_2d = head_size * head_size;
-
-    if (ith == 0) {
-        memset(dst_data, 0, T * C * sizeof(float));
-    }
-    ggml_barrier(params->threadpool);
-
-
-    #if defined(__AVX__) && !defined(__AVX512F__)
-        #define GGML_F32X GGML_F32x8
-        #define GGML_F32X_SET1 GGML_F32x8_SET1
-        #define GGML_F32X_LOAD GGML_F32x8_LOAD
-        #define GGML_F32X_STORE GGML_F32x8_STORE
-        #define GGML_F32X_MUL GGML_F32x8_MUL
-        #define GGML_F32X_FMA GGML_F32x8_FMA
-        #define WKV_VECTOR_SIZE 8
-    #elif defined(__AVX512F__)
-        #define GGML_F32X GGML_F32x16
-        #define GGML_F32X_SET1 GGML_F32x16_SET1
-        #define GGML_F32X_LOAD GGML_F32x16_LOAD
-        #define GGML_F32X_STORE GGML_F32x16_STORE
-        #define GGML_F32X_MUL GGML_F32x16_MUL
-        #define GGML_F32X_FMA GGML_F32x16_FMA
-        #define WKV_VECTOR_SIZE 16
-    #elif defined(__ARM_NEON) && defined(__aarch64__)
-        #define GGML_F32X GGML_F32x4
-        #define GGML_F32X_SET1 GGML_F32x4_SET1
-        #define GGML_F32X_LOAD GGML_F32x4_LOAD
-        #define GGML_F32X_STORE GGML_F32x4_STORE
-        #define GGML_F32X_MUL GGML_F32x4_MUL
-        #define GGML_F32X_FMA GGML_F32x4_FMA
-        #define WKV_VECTOR_SIZE 4
-    #endif
-
-    #ifdef WKV_VECTOR_SIZE
-        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
-
-        for (int64_t t = 0; t < T; t++) {
-            size_t t_offset = t * t_stride;
-            size_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                size_t h_offset = h * h_stride;
-                size_t t_h_offset = t_offset + h_offset;
-                size_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t i = 0; i < head_size; i++) {
-                    size_t t_h_i_offset = t_h_offset + i;
-                    size_t h_i_offset = h_offset + i;
-                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                    float k_val = k[t_h_i_offset];
-                    float r_val = r[t_h_i_offset];
-                    float time_faaaa_val = time_faaaa[h_i_offset];
-                    float time_decay_val = time_decay[t_h_i_offset];
-
-                    // Broadcast scalar values to vectors
-                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
-                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
-                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
-                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
-
-                    for (int64_t j = 0; j < vec_count; j++) {
-                        size_t base_j = j * WKV_VECTOR_SIZE;
-                        size_t t_h_j_offset = t_h_offset + base_j;
-                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
-
-                        // Load x elements at once
-                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
-                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
-                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
-
-                        // Compute kv = v * k
-                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
-
-                        // Compute temp = kv * time_faaaa + prev_state
-                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
-
-                        // Update dst: dst += temp * r
-                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
-                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
-
-                        // Update state: state = prev_state * time_decay + kv
-                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
-                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
-                    }
-
-                    // Handle remaining elements, this will not be used.
-                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
-                        size_t t_h_j_offset = t_h_offset + j;
-                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
-                        float v_val = v[t_h_j_offset];
-                        float kv_val = v_val * k_val;
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                        dst_data[t_h_j_offset] += temp_val * r_val;
-                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
-                    }
-                }
-            }
-        }
-
-    #else
-        // basically fused operations:
-        // dst = r @ (time_faaaa * (k @ v) + state),
-        // state = time_decay * state + (k @ v),
-        // recursive through each token
-        for (int64_t t = 0; t < T; t++) {
-            size_t t_offset = t * t_stride;
-            size_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                size_t h_offset = h * h_stride;
-                size_t t_h_offset = t_offset + h_offset;
-                size_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t i = 0; i < head_size; i++) {
-                    size_t t_h_i_offset = t_h_offset + i;
-                    size_t h_i_offset = h_offset + i;
-                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                    float k_val = k[t_h_i_offset];
-                    float r_val = r[t_h_i_offset];
-                    float time_faaaa_val = time_faaaa[h_i_offset];
-                    // RWKV v6: different time_decay for each token.
-                    float time_decay_val = time_decay[t_h_i_offset];
-
-                    for (int64_t j = 0; j < head_size; j++) {
-                        size_t t_h_j_offset = t_h_offset + j;
-                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float v_val = v[t_h_j_offset];
-                        float kv_val = v_val * k_val;
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                        dst_data[t_h_j_offset] += temp_val * r_val;
-                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
-                    }
-                }
-            }
-        }
-    #endif
-}
-
-
-static void ggml_compute_forward_rwkv_wkv6(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
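// Reference sketch (standalone, illustrative; not part of this patch) of the scalar
// WKV6 recurrence the removed kernel above vectorizes, for a single head and token:
//   kv          = k[i] * v[j]
//   out[j]     += (kv * u[i] + state[i][j]) * r[i]
//   state[i][j] = state[i][j] * w[i] + kv
// where u is time_faaaa and w is the per-token time_decay. Names are illustrative only.
static void wkv6_head_step_sketch(int head_size, float * out, float * state,
                                  const float * r, const float * k, const float * v,
                                  const float * u, const float * w) {
    for (int i = 0; i < head_size; ++i) {
        for (int j = 0; j < head_size; ++j) {
            const float kv   = k[i] * v[j];
            const float prev = state[i*head_size + j];
            out[j] += (kv * u[i] + prev) * r[i];
            state[i*head_size + j] = prev * w[i] + kv;
        }
    }
}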
-// ggml_compute_forward_gla
-
-static void ggml_compute_forward_gla_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[2];
-    const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[1];
-    const int64_t n_seqs = dst->src[4]->ne[1];
-    const int64_t head_size = C / HEADS;
-    const float scale = ggml_get_op_params_f32(dst, 0);
-
-    float * dst_data = (float *) dst->data;
-    float * state = ((float *) dst->data) + C * T;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    if (ith >= HEADS) {
-        return;
-    }
-
-    const int h_start = (HEADS * ith) / nth;
-    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
-                (HEADS * (ith + 1)) / nth : HEADS;
-
-    float * k = (float *) dst->src[0]->data;
-    float * v = (float *) dst->src[1]->data;
-    float * q = (float *) dst->src[2]->data;
-    float * g = (float *) dst->src[3]->data;
-
-    size_t t_stride = HEADS * head_size; // Same to C
-
-    size_t h_stride = C / HEADS;
-    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
-    size_t h_stride_2d = head_size * head_size;
-
-    if (ith == 0) {
-        memset(dst_data, 0, T * C * sizeof(float));
-    }
-    ggml_barrier(params->threadpool);
-
-
-    #if defined(__AVX__) && !defined(__AVX512F__)
-        #define GGML_F32X GGML_F32x8
-        #define GGML_F32X_SET1 GGML_F32x8_SET1
-        #define GGML_F32X_LOAD GGML_F32x8_LOAD
-        #define GGML_F32X_STORE GGML_F32x8_STORE
-        #define GGML_F32X_MUL GGML_F32x8_MUL
-        #define GGML_F32X_FMA GGML_F32x8_FMA
-        #define GLA_VECTOR_SIZE 8
-    #elif defined(__AVX512F__)
-        #define GGML_F32X GGML_F32x16
-        #define GGML_F32X_SET1 GGML_F32x16_SET1
-        #define GGML_F32X_LOAD GGML_F32x16_LOAD
-        #define GGML_F32X_STORE GGML_F32x16_STORE
-        #define GGML_F32X_MUL GGML_F32x16_MUL
-        #define GGML_F32X_FMA GGML_F32x16_FMA
-        #define GLA_VECTOR_SIZE 16
-    #elif defined(__ARM_NEON) && defined(__aarch64__)
-        #define GGML_F32X GGML_F32x4
-        #define GGML_F32X_SET1 GGML_F32x4_SET1
-        #define GGML_F32X_LOAD GGML_F32x4_LOAD
-        #define GGML_F32X_STORE GGML_F32x4_STORE
-        #define GGML_F32X_MUL GGML_F32x4_MUL
-        #define GGML_F32X_FMA GGML_F32x4_FMA
-        #define GLA_VECTOR_SIZE 4
-    #endif
-
-    #ifdef GLA_VECTOR_SIZE
-        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
-
-        for (int64_t t = 0; t < T; t++) {
-            size_t t_offset = t * t_stride;
-            size_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                size_t h_offset = h * h_stride;
-                size_t t_h_offset = t_offset + h_offset;
-                size_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t i = 0; i < head_size; i++) {
-                    size_t t_h_i_offset = t_h_offset + i;
-                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                    float k_val = k[t_h_i_offset];
-                    float q_val = q[t_h_i_offset] * scale;
-                    float g_val = g[t_h_i_offset];
-
-                    // Broadcast scalar values to vectors
-                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
-                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
-                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
-
-                    for (int64_t j = 0; j < vec_count; j++) {
-                        size_t base_j = j * GLA_VECTOR_SIZE;
-                        size_t t_h_j_offset = t_h_offset + base_j;
-                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
-
-                        // Load x elements at once
-                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
-                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
-                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
-
-                        // Compute kv = v * k
-                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
-
-                        // Compute temp = prev_state * g + kv
-                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
-
-                        // Update dst: dst += temp * q
-                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
-                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
-
-                        // Update state
-                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
-                    }
-
-                    // Handle remaining elements, this will not be used.
-                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
-                        size_t t_h_j_offset = t_h_offset + j;
-                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
-                        float v_val = v[t_h_j_offset];
-                        float kv_val = v_val * k_val;
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        float temp_val = kv_val + prev_state_val * g_val;
-                        dst_data[t_h_j_offset] += temp_val * q_val;
-                        state_cur[h_2d_i_j_offset] = temp_val;
-                    }
-                }
-            }
-        }
-
-    #else
-        for (int64_t t = 0; t < T; t++) {
-            size_t t_offset = t * t_stride;
-            size_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                size_t h_offset = h * h_stride;
-                size_t t_h_offset = t_offset + h_offset;
-                size_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t i = 0; i < head_size; i++) {
-                    size_t t_h_i_offset = t_h_offset + i;
-                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                    float k_val = k[t_h_i_offset];
-                    float q_val = q[t_h_i_offset] * scale;
-                    float g_val = g[t_h_i_offset];
-
-                    for (int64_t j = 0; j < head_size; j++) {
-                        size_t t_h_j_offset = t_h_offset + j;
-                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float v_val = v[t_h_j_offset];
-                        float kv_val = v_val * k_val;
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        float temp_val = prev_state_val * g_val + kv_val;
-                        dst_data[t_h_j_offset] += temp_val * q_val;
-                        state_cur[h_2d_i_j_offset] = temp_val;
-                    }
-                }
-            }
-        }
-    #endif
-}
-
-
-static void ggml_compute_forward_gla(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_gla_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
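// Reference sketch (standalone, illustrative; not part of this patch) of the gated
// linear attention recurrence the removed kernel above vectorizes, per head and token:
//   kv          = k[i] * v[j]
//   state[i][j] = state[i][j] * g[i] + kv
//   out[j]     += state[i][j] * (q[i] * scale)
// Names are illustrative only.
static void gla_head_step_sketch(int head_size, float scale, float * out, float * state,
                                 const float * q, const float * k, const float * v,
                                 const float * g) {
    for (int i = 0; i < head_size; ++i) {
        const float q_scaled = q[i] * scale;
        for (int j = 0; j < head_size; ++j) {
            const float kv   = k[i] * v[j];
            const float next = state[i*head_size + j] * g[i] + kv;
            out[j] += next * q_scaled;
            state[i*head_size + j] = next;
        }
    }
}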
-// ggml_compute_forward_map_unary
-
-static void ggml_compute_forward_map_unary_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_map_unary(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_unary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_binary
-
-static void ggml_compute_forward_map_binary_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(src1));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_map_binary(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_binary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_custom1
-
-static void ggml_compute_forward_map_custom1_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom1_op_f32_t fun) {
-
-    const struct ggml_tensor * a = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a);
-}
-
-// ggml_compute_forward_map_custom2
-
-static void ggml_compute_forward_map_custom2_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom2_op_f32_t fun) {
-
-    const struct ggml_tensor * a = dst->src[0];
-    const struct ggml_tensor * b = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b);
-}
-
-// ggml_compute_forward_map_custom3
-
-static void ggml_compute_forward_map_custom3_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom3_op_f32_t fun) {
-
-    const struct ggml_tensor * a = dst->src[0];
-    const struct ggml_tensor * b = dst->src[1];
-    const struct ggml_tensor * c = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b, c);
-}
-
-// ggml_compute_forward_map_custom1
-
-static void ggml_compute_forward_map_custom1(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a = dst->src[0];
-
-    struct ggml_map_custom1_op_params p;
-    memcpy(&p, dst->op_params, sizeof(p));
-
-    p.fun(dst, a, params->ith, params->nth, p.userdata);
-}
-
-// ggml_compute_forward_map_custom2
-
-static void ggml_compute_forward_map_custom2(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a = dst->src[0];
-    const struct ggml_tensor * b = dst->src[1];
-
-    struct ggml_map_custom2_op_params p;
-    memcpy(&p, dst->op_params, sizeof(p));
-
-    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
-}
-
-// ggml_compute_forward_map_custom3
-
-static void ggml_compute_forward_map_custom3(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a = dst->src[0];
-    const struct ggml_tensor * b = dst->src[1];
-    const struct ggml_tensor * c = dst->src[2];
-
-    struct ggml_map_custom3_op_params p;
-    memcpy(&p, dst->op_params, sizeof(p));
-
-    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
-}
-
-// ggml_compute_forward_cross_entropy_loss
-
-static void ggml_compute_forward_cross_entropy_loss_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    GGML_ASSERT(ggml_is_scalar(dst));
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0->ne[0];
-    const int64_t nr = ggml_nrows(src0);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    float * sums =  (float *) params->wdata;
-    float * st   = ((float *) params->wdata) + nth + ith*nc;
-    float sum_thread = 0.0f;
-
-    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
-
-    // rows per thread
-    const int64_t dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int64_t ir0 = dr*ith;
-    const int64_t ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
-        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
-        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
-
-#ifndef NDEBUG
-        for (int64_t i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(s0[i]));
-            assert(!isnan(s1[i]));
-        }
-#endif
-
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, s0);
-        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
-        assert(sum_softmax >= 0.0);
-
-        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
-        ggml_vec_mul_f32(nc, st, st, s1);
-
-        float sum_st = 0.0f;
-        ggml_vec_sum_f32(nc, &sum_st, st);
-        sum_thread += sum_st;
-
-#ifndef NDEBUG
-        for (int64_t i = 0; i < nc; ++i) {
-            assert(!isnan(st[i]));
-            assert(!isinf(st[i]));
-        }
-#endif
-    }
-    sums[ith] = sum_thread;
-    ggml_barrier(params->threadpool);
-
-    if (ith == 0) {
-        float * dp = (float *) dst->data;
-        ggml_vec_sum_f32(nth, dp, sums);
-        dp[0] *= -1.0f / (float) nr;
-    }
-}
-
-static void ggml_compute_forward_cross_entropy_loss(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
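// Reference sketch (standalone, illustrative; not part of this patch) of the per-row
// math in the removed cross_entropy_loss kernel: each row contributes
// -sum_i target[i] * log_softmax(logits)[i]; the kernel averages over rows at the end.
// In this sketch the negation is folded in directly. Names are illustrative only.
#include <math.h>

static float cross_entropy_row_sketch(int n, const float * logits, const float * target) {
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {
        if (logits[i] > max) {
            max = logits[i];
        }
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += expf(logits[i] - max);                    // softmax denominator
    }
    const float log_sum = logf(sum);
    float loss = 0.0f;
    for (int i = 0; i < n; ++i) {
        loss -= target[i] * (logits[i] - max - log_sum); // -target * log_softmax(logits)
    }
    return loss;
}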
-// ggml_compute_forward_cross_entropy_loss_back
-
-static void ggml_compute_forward_cross_entropy_loss_back_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
-    const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
-    const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
-
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(src0f));
-    GGML_ASSERT(ggml_is_contiguous(src1f));
-    GGML_ASSERT(ggml_is_contiguous(grad));
-    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
-
-    const int64_t ith = params->ith;
-    const int64_t nth = params->nth;
-
-    // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0f->ne[0];
-    const int64_t nr = ggml_nrows(src0f);
-
-    // rows per thread
-    const int64_t dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int64_t ir0 = dr*ith;
-    const int64_t ir1 = MIN(ir0 + dr, nr);
-
-    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
-
-    for (int64_t i1 = ir0; i1 < ir1; i1++) {
-        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
-        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
-        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
-
-#ifndef NDEBUG
-        for (int64_t i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(s0[i]));
-            assert(!isnan(s1[i]));
-        }
-#endif
-
-        // soft_max
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, s0);
-        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
-        assert(sum > 0.0);
-        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
-
-        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
-        ggml_vec_sub_f32(nc, ds0, ds0, s1);
-        ggml_vec_scale_f32(nc, ds0, d_by_nr);
-
-#ifndef NDEBUG
-        for (int64_t i = 0; i < nc; ++i) {
-            assert(!isnan(ds0[i]));
-            assert(!isinf(ds0[i]));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_cross_entropy_loss_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
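// Reference sketch (standalone, illustrative; not part of this patch) of the per-row
// gradient in the removed cross_entropy_loss_back kernel:
// grad(logits) = (softmax(logits) - target) * d / nr, with d the upstream scalar gradient.
// Names are illustrative only.
#include <math.h>

static void cross_entropy_back_row_sketch(int n, float * grad, const float * logits,
                                          const float * target, float d_by_nr) {
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {
        if (logits[i] > max) {
            max = logits[i];
        }
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        grad[i] = expf(logits[i] - max);
        sum += grad[i];
    }
    for (int i = 0; i < n; ++i) {
        grad[i] = (grad[i]/sum - target[i]) * d_by_nr;   // (softmax - target) * d / nr
    }
}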
-static void ggml_compute_forward_opt_step_adamw_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0         = dst->src[0];
-    const struct ggml_tensor * src0_grad    = dst->src[1];
-    const struct ggml_tensor * src0_grad_m  = dst->src[2];
-    const struct ggml_tensor * src0_grad_v  = dst->src[3];
-    const struct ggml_tensor * adamw_params = dst->src[4];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
-    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
-    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
-    const float alpha  = adamw_params_ptr[0];
-    const float beta1  = adamw_params_ptr[1];
-    const float beta2  = adamw_params_ptr[2];
-    const float eps    = adamw_params_ptr[3];
-    const float wd     = adamw_params_ptr[4];
-    const float beta1h = adamw_params_ptr[5];
-    const float beta2h = adamw_params_ptr[6];
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        const int64_t i03 = ir/(ne02*ne01);
-        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
-
-        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
-        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
-        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
-        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
-
-        for (int i00 = 0; i00 < ne00; ++i00) {
-            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
-            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
-
-            const float mh =       m[i00]*beta1h;
-            const float vh = sqrtf(v[i00]*beta2h) + eps;
-
-            // The weight decay is applied independently of the Adam momenta m and v.
-            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
-            // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
-        }
-    }
-}
-
-static void ggml_compute_forward_opt_step_adamw(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_opt_step_adamw_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
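The block removed above is the CPU AdamW kernel (presumably relocated to the new ops.cpp introduced later in this diff). It applies AdamW with decoupled weight decay: the momenta m and v are updated from the gradient, bias-corrected through the precomputed beta1h/beta2h factors carried in adamw_params, and the weight decay scales the weight directly rather than being folded into the gradient. For reference, a minimal scalar sketch of one step, with a hypothetical helper name, assuming the caller supplies the same seven parameters:

    #include <math.h>

    // One AdamW step for a single parameter (illustrative only, not a ggml API).
    // beta1h/beta2h are the per-iteration bias-correction factors that the
    // optimizer precomputes and passes in via adamw_params.
    static inline float adamw_step(float w, float g, float *m, float *v,
                                   float alpha, float beta1, float beta2,
                                   float eps, float wd,
                                   float beta1h, float beta2h) {
        *m = *m*beta1 +     g*(1.0f - beta1);    // first moment
        *v = *v*beta2 + g*g*(1.0f - beta2);      // second moment
        const float mh =        *m*beta1h;       // bias-corrected m
        const float vh = sqrtf(*v*beta2h) + eps; // bias-corrected sqrt(v)
        // decoupled weight decay, then the Adam update
        return w*(1.0f - alpha*wd) - alpha*mh/vh;
    }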
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -12965,7 +1729,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     }
 
     // extra_buffer op?
-    if (ggml_cpu_extra_compute_forward(params, tensor)) return;
+    if (ggml_cpu_extra_compute_forward(params, tensor)) {
+        return;
+    }
 
     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -13068,6 +1834,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_group_norm(params, tensor);
             } break;
+        case GGML_OP_L2_NORM:
+            {
+                ggml_compute_forward_l2_norm(params, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor);
@@ -13259,41 +2029,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gla(params, tensor);
             } break;
-        case GGML_OP_MAP_UNARY:
+        case GGML_OP_RWKV_WKV7:
             {
-                ggml_unary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_unary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_BINARY:
-            {
-                ggml_binary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_binary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM1_F32:
-            {
-                ggml_custom1_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom1_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM2_F32:
-            {
-                ggml_custom2_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom2_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM3_F32:
-            {
-                ggml_custom3_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom3_f32(params, tensor, fun);
-            }
-            break;
+                ggml_compute_forward_rwkv_wkv7(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -13309,6 +2048,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_custom3(params, tensor);
             }
             break;
+        case GGML_OP_CUSTOM:
+            {
+                ggml_compute_forward_custom(params, tensor);
+            }
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
             {
                 ggml_compute_forward_cross_entropy_loss(params, tensor);
@@ -13484,6 +2228,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
         case GGML_OP_MUL_MAT:
@@ -13551,19 +2296,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV6:
-        case GGML_OP_GATED_LINEAR_ATTN:
-        case GGML_OP_MAP_UNARY:
-        case GGML_OP_MAP_BINARY:
-        case GGML_OP_MAP_CUSTOM1_F32:
-        case GGML_OP_MAP_CUSTOM2_F32:
-        case GGML_OP_MAP_CUSTOM3_F32:
             {
                 n_tasks = 1;
             } break;
@@ -13597,6 +2338,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     n_tasks = MIN(p.n_tasks, n_threads);
                 }
             } break;
+        case GGML_OP_CUSTOM:
+            {
+                struct ggml_custom_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
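The GGML_OP_CUSTOM case added in the hunk above picks the task count from ggml_custom_op_params: GGML_N_TASKS_MAX means "use every available thread", and anything else is clamped to the thread-pool size. A plain-C paraphrase (the helper name is hypothetical):

    #include "ggml.h" // for GGML_N_TASKS_MAX

    // Task-count selection as in the GGML_OP_CUSTOM case above (sketch only).
    static int custom_op_n_tasks(int requested, int n_threads) {
        if (requested == GGML_N_TASKS_MAX) {
            return n_threads;                                  // use the whole pool
        }
        return requested < n_threads ? requested : n_threads; // clamp to pool size
    }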
@@ -13925,7 +2676,6 @@ struct ggml_cplan ggml_graph_plan(
         size_t cur = 0;
 
         if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
-
             switch (node->op) {
                 case GGML_OP_CPY:
                 case GGML_OP_DUP:
@@ -14034,9 +2784,10 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_FLASH_ATTN_EXT:
                     {
-                        const int64_t ne00 = node->src[0]->ne[0]; // D
+                        const int64_t ne10 = node->src[1]->ne[0]; // DK
+                        const int64_t ne20 = node->src[2]->ne[0]; // DV
 
-                        cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                        cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
                     } break;
                 case GGML_OP_FLASH_ATTN_BACK:
                     {
@@ -14486,6 +3237,14 @@ int ggml_cpu_has_amx_int8(void) {
 #endif
 }
 
+int ggml_cpu_has_bmi2(void) {
+#if defined(__BMI2__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
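In the ggml_graph_plan hunk above, the FLASH_ATTN_EXT scratch is no longer sized as three copies of the K head size; it now reserves one K-sized and two V-sized float rows per thread, which allows K and V to have different head dimensions. A small worked example, assuming DK=192, DV=128 and 8 tasks purely for illustration:

    #include <stdio.h>

    int main(void) {
        const long long ne10 = 192; // DK: K head size (assumed)
        const long long ne20 = 128; // DV: V head size (assumed)
        const int n_tasks = 8;      // threads working on the node (assumed)
        // same formula as the new code path above
        const size_t cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks;
        printf("flash-attn scratch: %zu bytes\n", cur); // 4*(192 + 256)*8 = 14336
        return 0;
    }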
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
index a84203f29..09f8382b9 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -511,6 +511,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
         if (ggml_cpu_has_fma()) {
             features.push_back({ "FMA", "1" });
         }
+        if (ggml_cpu_has_bmi2()) {
+            features.push_back({ "BMI2", "1" });
+        }
         if (ggml_cpu_has_avx512()) {
             features.push_back({ "AVX512", "1" });
         }
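The hunk above only reports the flag: detection itself is the compile-time pattern used by ggml_cpu_has_bmi2() in the previous file, i.e. the feature reads "1" exactly when the translation unit was built with the corresponding instruction set enabled. The same pattern, shown for a hypothetical flag (not a ggml API):

    // Returns 1 only if the compiler was invoked with the feature enabled
    // (e.g. -mavx2); no runtime CPUID query is involved.
    int my_cpu_has_avx2(void) {
    #if defined(__AVX2__)
        return 1;
    #else
        return 0;
    #endif
    }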
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index e0482c593..f6374f789 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -55,6 +55,7 @@
 
 #include 
 #include 
+#include <type_traits>
 
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
@@ -1092,13 +1093,403 @@ class tinyBLAS_Q0_PPC {
        }
     }
 
-    template
-    void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+    template
+    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array& comparray) {
         int64_t i, j;
         TA *aoffset = NULL;
         VA *vecOffset = NULL;
         TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
         TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        VB t1, t2, t3, t4, t5, t6, t7, t8;
+        const vector signed char lowMask = vec_splats((signed char)0xF);
+        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+        const vector signed char v8 = vec_splats((signed char)0x8);
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+
+                i = (cols >> 2);
+                if (i > 0) {
+                    do {
+                        c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                        c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                        c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                        c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+                        c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
+                        c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
+                        c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
+                        c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
+
+                        c1[0] = vec_and(c1[1], lowMask);
+                        c1[1] = vec_sr(c1[1], v4);
+                        c1[0] = vec_sub(c1[0], v8);
+                        c1[1] = vec_sub(c1[1], v8);
+                        vsum = vec_sum4s(c1[0], vsum);
+                        vsum2 = vec_sum4s(c1[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c2[0] = vec_and(c2[1], lowMask);
+                        c2[1] = vec_sr(c2[1], v4);
+                        c2[0] = vec_sub(c2[0], v8);
+                        c2[1] = vec_sub(c2[1], v8);
+                        vsum = vec_sum4s(c2[0], vsum);
+                        vsum2 = vec_sum4s(c2[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c3[0] = vec_and(c3[1], lowMask);
+                        c3[1] = vec_sr(c3[1], v4);
+                        c3[0] = vec_sub(c3[0], v8);
+                        c3[1] = vec_sub(c3[1], v8);
+                        vsum = vec_sum4s(c3[0], vsum);
+                        vsum2 = vec_sum4s(c3[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c4[0] = vec_and(c4[1], lowMask);
+                        c4[1] = vec_sr(c4[1], v4);
+                        c4[0] = vec_sub(c4[0], v8);
+                        c4[1] = vec_sub(c4[1], v8);
+                        vsum = vec_sum4s(c4[0], vsum);
+                        vsum2 = vec_sum4s(c4[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c5[0] = vec_and(c5[1], lowMask);
+                        c5[1] = vec_sr(c5[1], v4);
+                        c5[0] = vec_sub(c5[0], v8);
+                        c5[1] = vec_sub(c5[1], v8);
+                        vsum = vec_sum4s(c5[0], vsum);
+                        vsum2 = vec_sum4s(c5[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c6[0] = vec_and(c6[1], lowMask);
+                        c6[1] = vec_sr(c6[1], v4);
+                        c6[0] = vec_sub(c6[0], v8);
+                        c6[1] = vec_sub(c6[1], v8);
+                        vsum = vec_sum4s(c6[0], vsum);
+                        vsum2 = vec_sum4s(c6[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c7[0] = vec_and(c7[1], lowMask);
+                        c7[1] = vec_sr(c7[1], v4);
+                        c7[0] = vec_sub(c7[0], v8);
+                        c7[1] = vec_sub(c7[1], v8);
+                        vsum = vec_sum4s(c7[0], vsum);
+                        vsum2 = vec_sum4s(c7[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c8[0] = vec_and(c8[1], lowMask);
+                        c8[1] = vec_sr(c8[1], v4);
+                        c8[0] = vec_sub(c8[0], v8);
+                        c8[1] = vec_sub(c8[1], v8);
+                        vsum = vec_sum4s(c8[0], vsum);
+                        vsum2 = vec_sum4s(c8[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        t1 = vec_perm(c1[0], c2[0], swiz1);
+                        t2 = vec_perm(c1[0], c2[0], swiz2);
+                        t3 = vec_perm(c3[0], c4[0], swiz1);
+                        t4 = vec_perm(c3[0], c4[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset);
+                        vec_xst(t6, 0, vecOffset+16);
+                        vec_xst(t7, 0, vecOffset+32);
+                        vec_xst(t8, 0, vecOffset+48);
+
+                        t1 = vec_perm(c1[1], c2[1], swiz1);
+                        t2 = vec_perm(c1[1], c2[1], swiz2);
+                        t3 = vec_perm(c3[1], c4[1], swiz1);
+                        t4 = vec_perm(c3[1], c4[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+64);
+                        vec_xst(t6, 0, vecOffset+80);
+                        vec_xst(t7, 0, vecOffset+96);
+                        vec_xst(t8, 0, vecOffset+112);
+
+                        t1 = vec_perm(c5[0], c6[0], swiz1);
+                        t2 = vec_perm(c5[0], c6[0], swiz2);
+                        t3 = vec_perm(c7[0], c8[0], swiz1);
+                        t4 = vec_perm(c7[0], c8[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+128);
+                        vec_xst(t6, 0, vecOffset+144);
+                        vec_xst(t7, 0, vecOffset+160);
+                        vec_xst(t8, 0, vecOffset+176);
+
+                        t1 = vec_perm(c5[1], c6[1], swiz1);
+                        t2 = vec_perm(c5[1], c6[1], swiz2);
+                        t3 = vec_perm(c7[1], c8[1], swiz1);
+                        t4 = vec_perm(c7[1], c8[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+192);
+                        vec_xst(t6, 0, vecOffset+208);
+                        vec_xst(t7, 0, vecOffset+224);
+                        vec_xst(t8, 0, vecOffset+240);
+
+                        aoffset1 += lda;
+                        aoffset2 += lda;
+                        aoffset3 += lda;
+                        aoffset4 += lda;
+                        aoffset5 += lda;
+                        aoffset6 += lda;
+                        aoffset7 += lda;
+                        aoffset8 += lda;
+                        vecOffset += 256;
+                        i--;
+                    } while (i > 0);
+                }
+                j--;
+            } while (j > 0);
+        }
+
+        if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+
+            i = (cols >> 2);
+            if (i > 0) {
+                do {
+                    c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                    c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                    c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                    c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+
+                    c1[0] = vec_and(c1[1], lowMask);
+                    c1[1] = vec_sr(c1[1], v4);
+                    c1[0] = vec_sub(c1[0], v8);
+                    c1[1] = vec_sub(c1[1], v8);
+                    vsum = vec_sum4s(c1[0], vsum);
+                    vsum2 = vec_sum4s(c1[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c2[0] = vec_and(c2[1], lowMask);
+                    c2[1] = vec_sr(c2[1], v4);
+                    c2[0] = vec_sub(c2[0], v8);
+                    c2[1] = vec_sub(c2[1], v8);
+                    vsum = vec_sum4s(c2[0], vsum);
+                    vsum2 = vec_sum4s(c2[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c3[0] = vec_and(c3[1], lowMask);
+                    c3[1] = vec_sr(c3[1], v4);
+                    c3[0] = vec_sub(c3[0], v8);
+                    c3[1] = vec_sub(c3[1], v8);
+                    vsum = vec_sum4s(c3[0], vsum);
+                    vsum2 = vec_sum4s(c3[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c4[0] = vec_and(c4[1], lowMask);
+                    c4[1] = vec_sr(c4[1], v4);
+                    c4[0] = vec_sub(c4[0], v8);
+                    c4[1] = vec_sub(c4[1], v8);
+                    vsum = vec_sum4s(c4[0], vsum);
+                    vsum2 = vec_sum4s(c4[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats( 0);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while (i > 0);
+            }
+        }
+
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            i = (cols >> 2);
+            if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                        case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                        case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                            break;
+                    }
+                    c1[0] = vec_and(c1[1], lowMask);
+                    c1[1] = vec_sr(c1[1], v4);
+                    c1[0] = vec_sub(c1[0], v8);
+                    c1[1] = vec_sub(c1[1], v8);
+                    vsum = vec_sum4s(c1[0], vsum);
+                    vsum2 = vec_sum4s(c1[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c2[0] = vec_and(c2[1], lowMask);
+                    c2[1] = vec_sr(c2[1], v4);
+                    c2[0] = vec_sub(c2[0], v8);
+                    c2[1] = vec_sub(c2[1], v8);
+                    vsum = vec_sum4s(c2[0], vsum);
+                    vsum2 = vec_sum4s(c2[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c3[0] = vec_and(c3[1], lowMask);
+                    c3[1] = vec_sr(c3[1], v4);
+                    c3[0] = vec_sub(c3[0], v8);
+                    c3[1] = vec_sub(c3[1], v8);
+                    vsum = vec_sum4s(c3[0], vsum);
+                    vsum2 = vec_sum4s(c3[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c4[0] = vec_and(c4[1], lowMask);
+                    c4[1] = vec_sr(c4[1], v4);
+                    c4[0] = vec_sub(c4[0], v8);
+                    c4[1] = vec_sub(c4[1], v8);
+                    vsum = vec_sum4s(c4[0], vsum);
+                    vsum2 = vec_sum4s(c4[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while(i > 0);
+            }
+        }
+    }
+
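packNormalInt4 above is the Q4_0 path: each byte of qs is split into a low and a high nibble, recentred to the signed range [-8, 7] by subtracting 8, and the per-row sums are accumulated into comparray, which the kernels later turn into the -128 scaling correction when the integer accumulators are converted back to float. A scalar sketch of the per-byte step (illustrative only):

    #include <stdint.h>

    // One Q4_0 byte -> two signed 4-bit values plus their contribution to the
    // row sum ("comparray" entry). The vector code above does this 16 bytes
    // at a time with vec_and/vec_sr/vec_sub and vec_sum4s.
    static inline int unpack_q4_byte(uint8_t qs, int8_t out[2]) {
        out[0] = (int8_t)((qs & 0x0F) - 8); // low nibble
        out[1] = (int8_t)((qs >> 4)   - 8); // high nibble
        return out[0] + out[1];             // partial row sum
    }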
+    template
+    void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+        int64_t i, j;
+        TB *aoffset = NULL;
+        VA *vecOffset = NULL;
+        TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
         __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
         VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
         VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
@@ -1111,24 +1502,24 @@ class tinyBLAS_Q0_PPC {
         vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
         vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
 
-        aoffset = const_cast<TA*>(a);
+        aoffset = const_cast<TB*>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
             do {
-            aoffset1 = aoffset;
-            aoffset2 = aoffset1 + lda;
-            aoffset3 = aoffset2 + lda;
-            aoffset4 = aoffset3 + lda;
-            aoffset5 = aoffset4 + lda;
-            aoffset6 = aoffset5 + lda;
-            aoffset7 = aoffset6 + lda;
-            aoffset8 = aoffset7 + lda;
-            aoffset += 8 * lda;
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
 
-            i = (cols >> 3);
-            if (i > 0) {
-               do {
+                i = (cols >> 3);
+                if (i > 0) {
+                do {
                     C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
                     C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
                     C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1156,10 +1547,10 @@ class tinyBLAS_Q0_PPC {
                     t7 = vec_perm(t2, t4, swiz3);
                     t8 = vec_perm(t2, t4, swiz4);
                     if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
                     }
                     vec_xst(t5, 0, vecOffset);
                     vec_xst(t6, 0, vecOffset+16);
@@ -1175,10 +1566,10 @@ class tinyBLAS_Q0_PPC {
                     t7 = vec_perm(t2, t4, swiz3);
                     t8 = vec_perm(t2, t4, swiz4);
                     if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
                     }
                     vec_xst(t5, 0, vecOffset+64);
                     vec_xst(t6, 0, vecOffset+80);
@@ -1194,10 +1585,10 @@ class tinyBLAS_Q0_PPC {
                     t7 = vec_perm(t2, t4, swiz3);
                     t8 = vec_perm(t2, t4, swiz4);
                     if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
                     }
                     vec_xst(t5, 0, vecOffset+128);
                     vec_xst(t6, 0, vecOffset+144);
@@ -1213,10 +1604,10 @@ class tinyBLAS_Q0_PPC {
                     t7 = vec_perm(t2, t4, swiz3);
                     t8 = vec_perm(t2, t4, swiz4);
                     if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
                     }
                     vec_xst(t5, 0, vecOffset+192);
                     vec_xst(t6, 0, vecOffset+208);
@@ -1240,11 +1631,11 @@ class tinyBLAS_Q0_PPC {
     }
 
     if (rows & 4) {
-            aoffset1 = aoffset;
-            aoffset2 = aoffset1 + lda;
-            aoffset3 = aoffset2 + lda;
-            aoffset4 = aoffset3 + lda;
-            aoffset += 4 * lda;
+        aoffset1 = aoffset;
+        aoffset2 = aoffset1 + lda;
+        aoffset3 = aoffset2 + lda;
+        aoffset4 = aoffset3 + lda;
+        aoffset += 4 * lda;
 
         i = (cols >> 3);
             if (i > 0) {
@@ -1311,7 +1702,7 @@ class tinyBLAS_Q0_PPC {
             aoffset2 = aoffset1 + lda;
             aoffset3 = aoffset2 + lda;
             i = (cols >> 3);
-        if (i > 0) {
+            if (i > 0) {
                 do {
                     switch(rows) {
                         case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1527,13 +1918,18 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_4x8(int64_t ii, int64_t jj) {
         vec_t vec_A[8], vec_B[16] = {0};
         acc_t acc_0, acc_1;
-        std::array<int, 4> comparray;
+        std::array<int, 4> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
-            packNormal((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            if (std::is_same_v<TA, block_q4_0>) {
+               packNormalInt4((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+            } else {
+               packNormal((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            }
             packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1545,15 +1941,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                 }
             }
-            auto aoffset = A+(ii*lda)+l;
-            for (int i = 0; i < 4; i++) {
-                comparray[i] = 0;
-                int ca = 0;
-                const int8_t *at = aoffset->qs;
-                for (int j = 0; j < 32; j++)
-                    ca += (int)*at++;
-                comparray[i] = ca;
-                aoffset += lda;
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 4; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
@@ -1565,13 +1963,18 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_8x4(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[8] = {0};
         acc_t acc_0, acc_1;
-        std::array<int, 8> comparray;
+        std::array<int, 8> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
-            packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            if (std::is_same_v<TA, block_q4_0>) {
+               packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+            } else {
+               packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            }
             packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1582,15 +1985,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                 }
             }
-            auto aoffset = A+(ii*lda)+l;
-            for (int i = 0; i < 8; i++) {
-                comparray[i] = 0;
-                int ca = 0;
-                const int8_t *at = aoffset->qs;
-                for (int j = 0; j < 32; j++)
-                    ca += (int)*at++;
-                comparray[i] = ca;
-                aoffset += lda;
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 8; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1602,15 +2007,20 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_8x8(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[16] = {0};
         acc_t acc_0, acc_1, acc_2, acc_3;
-        std::array<int, 8> comparray;
+        std::array<int, 8> comparray {};
         vector float fin_res[16] = {0};
         vector float vs[16] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
             __builtin_mma_xxsetaccz(&acc_2);
             __builtin_mma_xxsetaccz(&acc_3);
-            packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            if (std::is_same_v<TA, block_q4_0>) {
+               packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+            } else {
+               packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            }
             packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1624,15 +2034,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                 }
             }
-            auto aoffset = A+(ii*lda)+l;
-            for (int i = 0; i < 8; i++) {
-                comparray[i] = 0;
-                int ca = 0;
-                const int8_t *at = aoffset->qs;
-                for (int j = 0; j < 32; j++)
-                    ca += (int)*at++;
-                comparray[i] = ca;
-                aoffset += lda;
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 8; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1653,16 +2065,17 @@ class tinyBLAS_Q0_PPC {
         int64_t duty = (tiles + nth - 1) / nth;
         int64_t start = duty * ith;
         int64_t end = start + duty;
-        vec_t vec_A[8], vec_B[8] = {0};
+        vec_t vec_A[8] = {0}, vec_B[8] = {0};
         vector signed int vec_C[4];
         acc_t acc_0;
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
 
         if (end > tiles)
             end = tiles;
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-            std::array<int, RM> comparray;
+            std::array<int, RM> comparray{};
             vector float res[4] = {0};
             vector float fin_res[4] = {0};
             vector float vs[4] = {0};
@@ -1673,7 +2086,11 @@ class tinyBLAS_Q0_PPC {
                 __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                 __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                 __builtin_mma_xxsetaccz(&acc_0);
-                packNormal((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                if (isAblock_q4) {
+                   packNormalInt4((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+                } else {
+                   packNormal((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                }
                 packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
                 for(int x = 0; x < 8; x+=4) {
                     __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1687,17 +2104,18 @@ class tinyBLAS_Q0_PPC {
                     }
                 }
                 __builtin_mma_disassemble_acc(vec_C, &acc_0);
-                auto aoffset = A+(ii*lda)+l;
-                for (int i = 0; i < RM; i++) {
-                    comparray[i] = 0;
-                    int ca = 0;
-                    const int8_t *at = aoffset->qs;
-                    for (int j = 0; j < 32; j++)
-                        ca += (int)*at++;
-                    comparray[i] = ca;
-                    aoffset += lda;
+                if (!isAblock_q4) {
+                    auto aoffset = A+(ii*lda)+l;
+                    for (int i = 0; i < RM; i++) {
+                        comparray[i] = 0;
+                        int ca = 0;
+                        auto *at = aoffset->qs;
+                        for (int j = 0; j < 32; j++)
+                            ca += (int)*at++;
+                        comparray[i] = ca;
+                        aoffset += lda;
+                    }
                 }
-
                 for (int i = 0; i < RM; i++) {
                     CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
                     res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
@@ -2013,6 +2431,7 @@ class tinyBLAS_PPC {
             }
         }
     }
+
     void KERNEL_4x4(int64_t ii, int64_t jj) {
         vec_t vec_A[4], vec_B[4], vec_C[4];
         acc_t acc_0;
@@ -2259,15 +2678,27 @@ class tinyBLAS_PPC {
             vec_t vec_C[4];
             acc_t acc_0;
             __builtin_mma_xxsetaccz(&acc_0);
-            vec_t vec_A[4], vec_B[4];
+            vec_t vec_A[4] {0}, vec_B[4] = {0};
             for (int l=0; l<k; l+=4) {
-                if (RN >= 4 && RM == 1) {
+                /* The 'GEMV forwarding' concept is used in the first two conditional
+                 * branches: when one of the matrices has a single row/column, its
+                 * elements are broadcast instead of being prepacked by the packing
+                 * routine.
+                 */
+                if (RM == 1) {
+                    TA* a = const_cast<TA*>(A+(ii)*lda+l);
-                    packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+                    packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
                     vec_A[0] = (vec_t)vec_xl(0,a);
                     vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
                     vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
                     vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+                } else if (RN == 1) {
+                    packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+                    TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+                    vec_B[0] = (vec_t)vec_xl(0,b);
+                    vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
+                    vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
+                    vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
                 } else {
                     packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
                     packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
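The RM == 1 / RN == 1 branches added above implement the GEMV forwarding described in the comment: when the tile degenerates to a single row or column there is nothing to prepack, so the lone vector's elements are loaded once and broadcast against the other operand. A scalar illustration of the row case (hypothetical helper, not the SIMD path):

    #include <stdint.h>

    // c[j] = sum_l a[l] * B[j*ldb + l] for a single row 'a' of length k:
    // each a[l] is reused (broadcast) across all n output columns.
    static void gemv_row(const float *a, const float *B, int64_t ldb,
                         float *c, int64_t k, int64_t n) {
        for (int64_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (int64_t l = 0; l < k; ++l) {
                acc += a[l] * B[j*ldb + l];
            }
            c[j] = acc;
        }
    }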
@@ -2371,8 +2802,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
     assert(params->ith < params->nth);
 
     // only enable sgemm for prompt processing
+#if !defined(__MMA__)
     if (n < 2)
         return false;
+#endif
 
     if (Ctype != GGML_TYPE_F32)
         return false;
@@ -2503,8 +2936,8 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
-
 #elif defined(__MMA__)
+        // TODO: Remove this condition once GEMV forwarding is enabled.
         if (n < 8 && n != 4)
            return false;
         if (m < 8 && m != 4)
@@ -2516,7 +2949,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
-
 #else
         return false;
 #endif
@@ -2541,6 +2973,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__MMA__)
+        // TODO: Remove this condition once GEMV forwarding is enabled.
+        if (n < 8 && n != 4)
+           return false;
+        if (m < 8 && m != 4)
+           return false;
+        tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
+            k, (const block_q4_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            params->ith, params->nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif
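With __MMA__ defined, the early n < 2 bail-out above is compiled out, but both MMA branches still reject tile shapes the kernels cannot handle; as the TODO notes, the restriction is meant to go away once GEMV forwarding lands for these paths. The accepted shapes, paraphrased with a hypothetical helper name:

    #include <stdbool.h>
    #include <stdint.h>

    // A dimension passes the gate when it is exactly 4 or at least 8,
    // mirroring the two "if (x < 8 && x != 4) return false;" checks above.
    static bool mma_shape_supported(int64_t m, int64_t n) {
        const bool n_ok = (n >= 8) || (n == 4);
        const bool m_ok = (m >= 8) || (m == 4);
        return m_ok && n_ok;
    }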
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp
new file mode 100644
index 000000000..66b8da68f
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp
@@ -0,0 +1,8692 @@
+#include "ops.h"
+
+#include "ggml-cpu.h"
+#include "ggml-impl.h"
+#include "binary-ops.h"
+#include "unary-ops.h"
+#include "vec.h"
+
+#include <float.h>
+
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+
+// disable POSIX deprecation warnings
+// these functions are never going away, anyway
+#pragma warning(disable: 4996)
+
+// unreachable code because of multiple instances of code after GGML_ABORT
+#pragma warning(disable: 4702)
+#endif
+
+// ggml_compute_forward_dup
+
+static void ggml_compute_forward_dup_same_cont(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    const size_t nb0 = ggml_type_size(src0->type);
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by blocks
+    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
+    const int dr = (nk + nth - 1) / nth;
+    const int k0 = dr * ith;
+    const int k1 = MIN(k0 + dr, nk);
+
+    if (k0 < k1) {
+        memcpy(
+            ((char *)  dst->data + k0*nb0),
+            ((char *) src0->data + k0*nb0),
+            (k1 - k0) * nb0);
+    }
+}
+
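ggml_compute_forward_dup_same_cont above, like most kernels in this file, splits work with the same ceil-divide pattern: dr = (n + nth - 1)/nth elements per thread, with thread ith owning [ith*dr, min(ith*dr + dr, n)). A tiny standalone example with made-up sizes:

    #include <stdio.h>

    int main(void) {
        const int nk = 10, nth = 4;            // 10 blocks over 4 threads (assumed)
        const int dr = (nk + nth - 1)/nth;     // 3 blocks per thread
        for (int ith = 0; ith < nth; ++ith) {
            const int k0 = dr*ith;
            const int k1 = (k0 + dr < nk) ? k0 + dr : nk;
            printf("thread %d copies blocks [%d, %d)\n", ith, k0, k1);
        }
        return 0;                              // ranges: [0,3) [3,6) [6,9) [9,10)
    }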
+static void ggml_compute_forward_dup_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
+
+    if (ggml_is_contiguous(dst)) {
+        if (nb00 == sizeof(ggml_fp16_t)) {
+            if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                            }
+
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        }
+        return;
+    }
+
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
+
+                        if (++i10 == ne00) {
+                            i10 = 0;
+                            if (++i11 == ne01) {
+                                i11 = 0;
+                                if (++i12 == ne02) {
+                                    i12 = 0;
+                                    if (++i13 == ne03) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+static void ggml_compute_forward_dup_bf16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
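+    // e.g. with nr = 10 rows and nth = 4 threads: dr = 3, so the threads cover rows [0,3), [3,6), [6,9) and [9,10)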
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
+
+    if (ggml_is_contiguous(dst)) {
+        if (nb00 == sizeof(ggml_bf16_t)) {
+            if (dst->type == GGML_TYPE_BF16) {
+                size_t id = 0;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
+
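+                // id is a running byte offset into dst; the adjustments before and after the
+                // inner row loop skip over the rows that belong to other threads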
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
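+                // per-thread f32 scratch row inside params->wdata; the CACHE_LINE_SIZE_F32 padding
+                // keeps each thread's row on its own cache lines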
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
+                            }
+
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_BF16) {
+                size_t id = 0;
+                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        }
+        return;
+    }
+
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
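+    // (i10, i11, i12, i13) walk dst element by element like an odometer: i10 is the fastest
+    // dimension and each wrap carries into the next one; the while loops below fast-forward
+    // these counters past the rows handled by other threads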
+
+    if (dst->type == GGML_TYPE_BF16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
+
+                        if (++i10 == ne00) {
+                            i10 = 0;
+                            if (++i11 == ne01) {
+                                i11 = 0;
+                                if (++i12 == ne02) {
+                                    i12 = 0;
+                                    if (++i13 == ne03) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+static void ggml_compute_forward_dup_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        // TODO: simplify
+        if (nb00 == sizeof(float)) {
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_BF16) {
+                size_t id = 0;
+                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        }
+
+        return;
+    }
+
+    // dst counters
+
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(float));
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_BF16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+// A simplified version of ggml_compute_forward_dup that does no float upcasting and just uses a plain memcpy.
+static void ggml_compute_forward_dup_bytes(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
+        ggml_compute_forward_dup_same_cont(params, dst);
+        return;
+    }
+
+    const size_t type_size = ggml_type_size(src0->type);
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ggml_are_same_shape(src0, dst) &&
+        nb00 == type_size && nb0 == type_size) {
+        // copy by rows
+        const size_t rs = ggml_row_size(src0->type, ne00);
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        size_t id = 0;
+        char * dst_ptr = (char *) dst->data;
+        const size_t rs = ne00 * type_size;
+
+        if (nb00 == type_size) {
+            // src0 is contiguous on first dimension, copy by rows
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        memcpy(dst_ptr + id, src0_ptr, rs);
+                        id += rs;
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, type_size);
+
+                            id += type_size;
+                        }
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        }
+
+        return;
+    }
+
+    // dst counters
+    int64_t k10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    // number of blocks in a row
+    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
+    const int64_t nk0  = ne0  / ggml_blck_size(dst->type);
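+    // src0 and dst share the same type, so each memcpy below moves one whole block of
+    // type_size bytes; k00 and k10 therefore count blocks rather than elements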
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            k10 += nk00 * ir0;
+            while (k10 >= nk0) {
+                k10 -= nk0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+            for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                for (int64_t k00 = 0; k00 < nk00; k00++) {
+                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          char * dst_ptr  = ((char *)  dst->data + k10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                    memcpy(dst_ptr, src0_ptr, type_size);
+
+                    if (++k10 == nk0) {
+                        k10 = 0;
+                        if (++i11 == ne1) {
+                            i11 = 0;
+                            if (++i12 == ne2) {
+                                i12 = 0;
+                                if (++i13 == ne3) {
+                                    i13 = 0;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            k10 += nk00 * (ne01 - ir1);
+            while (k10 >= nk0) {
+                k10 -= nk0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_dup_q(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+
+    size_t qk = ggml_blck_size(type);
+    const int64_t nr = ggml_nelements(src1) / qk;
+
+    // destination must be contiguous in the first dimension
+    GGML_ASSERT(nb10 == ggml_type_size(dst->type));
+    // must either have first dimension large enough to hold a row, or fully contiguous
+    GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
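+    // each ir is one block of src0: i = ir*qk is the flat index of its first element, which is
+    // unraveled into 4D coordinates for the quantized source and the float destination below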
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+
+        uint32_t i = ir * qk;
+
+        const int64_t i03 = i/(ne00 * ne01 * ne02);
+        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+        const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+        const int64_t i13 = i/(ne10 * ne11 * ne12);
+        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+        dequantize_row_q(
+                (const void *) ((char *) src0->data + x_offset),
+                     (float *) ((char *)  dst->data + dst_offset), qk);
+    }
+}
+
+void ggml_compute_forward_dup(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (src0->type == dst->type) {
+        ggml_compute_forward_dup_bytes(params, dst);
+        return;
+    }
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_dup_f16(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_dup_bf16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_dup_f32(params, dst);
+            } break;
+        default:
+            {
+                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_dup_q(params, dst);
+                    break;
+                }
+                GGML_ABORT("fatal error");
+            }
+    }
+}
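+
+// note: in the CPU op dispatch, GGML_OP_CPY and GGML_OP_CONT are computed with this same dup
+// kernel, so type-converting copies also go through the switch above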
+
+// ggml_compute_forward_add
+
+static void ggml_compute_forward_add_q_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const ggml_type type = src0->type;
+    const ggml_type dtype = dst->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // dequantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, wdata, src1_row);
+        // quantize row to dst
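+        // (dst types without a from_float converter, e.g. an f32 dst, take the memcpy path and
+        // receive the accumulated f32 row as-is)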
+        if (quantize_row_q != NULL) {
+            quantize_row_q(wdata, dst_row, ne00);
+        } else {
+            memcpy(dst_row, wdata, ne0*nb0);
+        }
+    }
+}
+
+void ggml_compute_forward_add(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_add_non_quantized(params, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+            {
+                ggml_compute_forward_add_q_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_add1
+
+static void ggml_compute_forward_add1_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+#ifdef GGML_USE_ACCELERATE
+        GGML_UNUSED(ggml_vec_add1_f32);
+
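+        // src1 is a scalar; passing it to vDSP_vadd with stride 0 adds the same value to every element of the row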
+        vDSP_vadd(
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
+                (float *) ((char *) src1->data), 0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
+                ne0);
+#else
+        ggml_vec_add1_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+               *(float *) src1->data);
+#endif
+    }
+}
+
+static void ggml_compute_forward_add1_f16_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_f16_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_q_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
+
+    // we don't support permuted src0
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        void  * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
+        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb0 ));
+
+        assert(ne0 % 32 == 0);
+
+        // dequantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne0);
+        // add src1
+        ggml_vec_acc1_f32(ne0, wdata, v);
+        // quantize row to dst
+        quantize_row_q(wdata, dst_row, ne0);
+    }
+}
+
+static void ggml_compute_forward_add1_bf16_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_bf16_bf16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+void ggml_compute_forward_add1(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add1_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add1_f16_f16(params, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add1_f16_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                if (src1->type == GGML_TYPE_BF16) {
+                    ggml_compute_forward_add1_bf16_bf16(params, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add1_bf16_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+            {
+                ggml_compute_forward_add1_q_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_acc
+
+static void ggml_compute_forward_acc_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset in bytes during acc
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
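+        // every thread waits here so the copy is fully visible before any thread starts accumulating src1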
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during acc
+    const size_t nb0 = ggml_element_size(src0);
+
+    const size_t nb00 = nb0;
+    const size_t nb01 = nb1;
+    const size_t nb02 = nb2;
+    const size_t nb03 = nb3;
+
+    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));
+    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+#ifdef GGML_USE_ACCELERATE
+        vDSP_vadd(
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
+#else
+        ggml_vec_add_f32(nc,
+                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+#endif
+    }
+}
+
+void ggml_compute_forward_acc(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_acc_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sum
+
+static void ggml_compute_forward_sum_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+    assert(src0->nb[0] == sizeof(float));
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+
+    ggml_float sum     = 0;
+    ggml_float row_sum = 0;
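+    // row sums are accumulated in ggml_float (typically double) to limit rounding error over many rows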
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f32_ggf(ne00,
+                        &row_sum,
+                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+                sum += row_sum;
+            }
+        }
+    }
+    ((float *) dst->data)[0] = sum;
+}
+
+static void ggml_compute_forward_sum_f16(
+    const ggml_compute_params * params,
+          ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+
+    float sum = 0;
+    float row_sum = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f16_ggf(ne00,
+                    &row_sum,
+                    (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
+                sum += row_sum;
+            }
+        }
+    }
+    ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
+}
+
+static void ggml_compute_forward_sum_bf16(
+    const ggml_compute_params * params,
+          ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+
+    assert(src0->nb[0] == sizeof(ggml_bf16_t));
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+
+    float sum = 0;
+    float row_sum = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_bf16_ggf(ne00,
+                    &row_sum,
+                    (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
+                sum += row_sum;
+            }
+        }
+    }
+    ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
+}
+
+void ggml_compute_forward_sum(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sum_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sum_f16(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_sum_bf16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sum_rows
+
+static void ggml_compute_forward_sum_rows_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(ne0 == 1);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    for (int64_t i3 = 0; i3 < ne03; i3++) {
+        for (int64_t i2 = 0; i2 < ne02; i2++) {
+            for (int64_t i1 = 0; i1 < ne01; i1++) {
+                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
+                float row_sum = 0;
+                ggml_vec_sum_f32(ne00, &row_sum, src_row);
+                dst_row[0] = row_sum;
+            }
+        }
+    }
+}
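+
+// e.g. src0 = [4, 2, 1, 1] with rows {1,2,3,4} and {5,6,7,8} produces
+// dst = [1, 2, 1, 1] with values {10, 26}: each output element is one row sum.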
+
+void ggml_compute_forward_sum_rows(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sum_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_mean
+
+static void ggml_compute_forward_mean_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(src0->nb[0] == sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    assert(ne0 == 1);
+    assert(ne1 == ne01);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
+
+    GGML_UNUSED(ne0);
+    GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2);
+    GGML_UNUSED(ne3);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f32(ne00,
+                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+
+                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
+            }
+        }
+    }
+}
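+
+// e.g. a row {1, 2, 3, 4} first gets its sum 10 written to dst, which is then
+// divided by ne00 = 4 in place, giving a mean of 2.5.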
+
+void ggml_compute_forward_mean(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mean_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_argmax
+
+static void ggml_compute_forward_argmax_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(src0->nb[0] == sizeof(float));
+    assert(dst->nb[0] == sizeof(float));
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb0 = dst->nb[0];
+
+    for (int64_t i1 = 0; i1 < ne01; i1++) {
+        float * src = (float *) ((char *) src0->data + i1*nb01);
+        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);
+        int v = 0;
+        ggml_vec_argmax_f32(ne00, &v, src);
+        dst_[0] = v;
+    }
+}
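+
+// e.g. the row {0.1f, 2.0f, -3.0f, 0.5f} yields dst_[0] = 1, the index of the
+// largest value; one int32 index is produced per src0 row.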
+
+void ggml_compute_forward_argmax(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_argmax_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_count_equal
+
+static void ggml_compute_forward_count_equal_i32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(src0->type == GGML_TYPE_I32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_I64);
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    int64_t * sums = (int64_t *) params->wdata;
+    int64_t sum_thread = 0;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // decompose the flat row index ir into (i01, i02, i03)
+        const int64_t i03 =  ir                              / (ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)             /       ne01;
+        const int64_t i01 =  ir - i03*ne02*ne01 - i02*ne01;
+
+        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
+        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
+
+        for (int64_t i00 = 0; i00 < ne00; ++i00) {
+            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
+            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
+
+            sum_thread += val0 == val1;
+        }
+    }
+    if (ith != 0) {
+        sums[ith] = sum_thread;
+    }
+    ggml_barrier(params->threadpool);
+
+    if (ith != 0) {
+        return;
+    }
+
+    for (int ith_other = 1; ith_other < nth; ++ith_other) {
+        sum_thread += sums[ith_other];
+    }
+    *((int64_t *) dst->data) = sum_thread;
+}
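+
+// worked example of the row split: with nr = 10 and nth = 4, dr = (10+4-1)/4 = 3,
+// so the per-thread ranges are [0,3), [3,6), [6,9), [9,10); each thread stores its
+// partial count in sums[ith] and thread 0 reduces them after the barrier.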
+
+void ggml_compute_forward_count_equal(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_count_equal_i32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_repeat
+
+static void ggml_compute_forward_repeat_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3;  i3++) {
+        for                     (int k3 = 0; k3 < ne03; k3++) {
+            for                 (int i2 = 0; i2 < nr2;  i2++) {
+                for             (int k2 = 0; k2 < ne02; k2++) {
+                    for         (int i1 = 0; i1 < nr1;  i1++) {
+                        for     (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0;  i0++) {
+                                ggml_vec_cpy_f32(ne00,
+                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),
+                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
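+
+// e.g. repeating src0 = {a, b} (ne00 = 2) into dst with ne0 = 4 gives nr0 = 2 and
+// dst = {a, b, a, b}; higher dimensions tile the same way via nr1, nr2, nr3.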
+
+static void ggml_compute_forward_repeat_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3;  i3++) {
+        for                     (int k3 = 0; k3 < ne03; k3++) {
+            for                 (int i2 = 0; i2 < nr2;  i2++) {
+                for             (int k2 = 0; k2 < ne02; k2++) {
+                    for         (int i1 = 0; i1 < nr1;  i1++) {
+                        for     (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0;  i0++) {
+                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
+                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
+                                // ggml_vec_cpy_f16(ne00, y, x)
+                                for (int i = 0; i < ne00; ++i) {
+                                    y[i]  = x[i];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_repeat(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_I16:
+            {
+                ggml_compute_forward_repeat_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_repeat_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_repeat_back
+
+static void ggml_compute_forward_repeat_back_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne00/ne0);
+    const int nr1 = (int)(ne01/ne1);
+    const int nr2 = (int)(ne02/ne2);
+    const int nr3 = (int)(ne03/ne3);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (ggml_is_contiguous(dst)) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
+    } else {
+        for         (int k3 = 0; k3 < ne3; k3++) {
+            for     (int k2 = 0; k2 < ne2; k2++) {
+                for (int k1 = 0; k1 < ne1; k1++) {
+                    ggml_vec_set_f32(ne0,
+                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+                        0);
+                }
+            }
+        }
+    }
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3; i3++) {
+        for                     (int k3 = 0; k3 < ne3; k3++) {
+            for                 (int i2 = 0; i2 < nr2; i2++) {
+                for             (int k2 = 0; k2 < ne2; k2++) {
+                    for         (int i1 = 0; i1 < nr1; i1++) {
+                        for     (int k1 = 0; k1 < ne1; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_acc_f32(ne0,
+                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
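+
+// e.g. with src0 = {a, b, c, d} (ne00 = 4) reduced into dst with ne0 = 2, nr0 = 2
+// and dst = {a + c, b + d}: every output element accumulates all of its repeats.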
+
+void ggml_compute_forward_repeat_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_concat
+
+static void ggml_compute_forward_concat_any(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    const size_t len = ggml_type_size(src0->type);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const char * x;
+
+    // TODO: smarter multi-threading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03;
+                    } else {
+                        x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
+                    }
+
+                    char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
+
+                    memcpy(y, x, len);
+                }
+            }
+        }
+    }
+}
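+
+// e.g. concatenating along dim = 2 sets o = {0, 0, ne02, 0}; destination elements
+// with i2 < ne02 are copied from src0, the rest from src1 at index i2 - o[2].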
+
+static void ggml_compute_forward_concat_i8(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const int8_t * x;
+
+    // TODO: smarter multi-threading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const int8_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
+                    }
+
+                    int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_concat_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const ggml_fp16_t * x;
+
+    // TODO: smarter multi-threading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const ggml_fp16_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
+                    }
+
+                    ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_concat_f32(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const float * x;
+
+    // TODO: smarter multi-threading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
+                    }
+
+                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_concat(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_I16:
+            {
+                ggml_compute_forward_concat_f16(params, dst);
+            } break;
+        case GGML_TYPE_I8:
+            {
+                ggml_compute_forward_concat_i8(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_concat_f32(params, dst);
+            } break;
+        default:
+            {
+                ggml_compute_forward_concat_any(params, dst);
+            }
+    }
+}
+
+// ggml_compute_forward_gelu
+
+static void ggml_compute_forward_gelu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
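+
+// ggml_vec_gelu_f32 applies GELU row by row, i.e. roughly
+// gelu(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))); gelu(0) = 0.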
+
+static void ggml_compute_forward_gelu_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_gelu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_gelu_quick
+
+static void ggml_compute_forward_gelu_quick_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_quick_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
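+
+// gelu_quick is the cheaper sigmoid form, roughly x * sigmoid(1.702 * x),
+// applied independently to every element of each row.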
+
+static void ggml_compute_forward_gelu_quick_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_quick_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_quick(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_quick_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_gelu_quick_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_silu
+
+static void ggml_compute_forward_silu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
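+
+// silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), applied element-wise per row,
+// e.g. silu(0) = 0 and silu(x) -> x for large positive x.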
+
+static void ggml_compute_forward_silu_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_silu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_leaky_relu
+
+static void ggml_compute_forward_leaky_relu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_leaky_relu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+    }
+}
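+
+// leaky_relu(x) = x for x > 0 and negative_slope * x otherwise; e.g. with
+// negative_slope = 0.1f, an input of -2.0f maps to -0.2f.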
+
+static void ggml_compute_forward_leaky_relu_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    assert(dst->nb[0]  == sizeof(ggml_fp16_t));
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_leaky_relu_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+    }
+}
+
+void ggml_compute_forward_leaky_relu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_leaky_relu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_leaky_relu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_silu_back
+
+static void ggml_compute_forward_silu_back_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * grad = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_is_contiguous_1(grad));
+    assert(ggml_is_contiguous_1(src1));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_backward_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src1->data + i1*(src1->nb[1])),
+                (float *) ((char *) grad->data + i1*(grad->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
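+
+// per element, with s = sigmoid(x): d/dx silu(x) = s * (1 + x * (1 - s)), so the
+// vec call above roughly computes dst = grad * s * (1 + x * (1 - s)) for each row.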
+
+static void ggml_compute_forward_silu_back_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * grad = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_is_contiguous_1(grad));
+    assert(ggml_is_contiguous_1(src1));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_backward_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
+                (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+void ggml_compute_forward_silu_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_back_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_silu_back_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_norm
+
+static void ggml_compute_forward_norm_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps >= 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)x[i00];
+                }
+
+                float mean = sum/ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                ggml_float sum2 = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    float v = x[i00] - mean;
+                    y[i00] = v;
+                    sum2 += (ggml_float)(v*v);
+                }
+
+                float variance = sum2/ne00;
+                const float scale = 1.0f/sqrtf(variance + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
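+
+// per row this is standard layer normalization without affine parameters:
+// y = (x - mean(x)) / sqrt(var(x) + eps), with mean and variance taken over ne00.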
+
+void ggml_compute_forward_norm(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rms_norm
+
+static void ggml_compute_forward_rms_norm_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps >= 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                const float mean = sum/ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+                // for (int i00 = 0; i00 < ne00; i00++) {
+                //     y[i00] = x[i00];
+                // }
+
+                const float scale = 1.0f/sqrtf(mean + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
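+
+// per row: y = x / sqrt(mean(x^2) + eps), i.e. the row is rescaled by the
+// reciprocal of its root-mean-square (no centering, unlike ggml_compute_forward_norm).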
+
+void ggml_compute_forward_rms_norm(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_rms_norm_back_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
+    const ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                // src1 is same shape as src0 => same indices
+                const int64_t i11 = i01;
+                const int64_t i12 = i02;
+                const int64_t i13 = i03;
+
+                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+
+                ggml_float sum_xx  = 0.0;
+                ggml_float sum_xdz = 0.0;
+
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum_xx  += (ggml_float)(x[i00] * x[i00]);
+                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);
+                }
+
+                //const float mean     = (float)(sum_xx)/ne00;
+                const float mean_eps = (float)(sum_xx)/ne00 + eps;
+                const float sum_eps  = (float)(sum_xx) + eps*ne00;
+                //const float mean_xdz = (float)(sum_xdz)/ne00;
+                // we could cache rms from forward pass to improve performance.
+                // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
+                //const float rms      = sqrtf(mean_eps);
+                const float rrms     = 1.0f / sqrtf(mean_eps);
+                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
+
+                {
+                    // z = rms_norm(x)
+                    //
+                    // rms_norm(src1) =
+                    //     scale(
+                    //         src1,
+                    //         div(
+                    //             1,
+                    //             sqrt(
+                    //                 add(
+                    //                     scale(
+                    //                         sum(
+                    //                             sqr(
+                    //                                 src1)),
+                    //                         (1.0/N)),
+                    //                     eps))));
+
+                    // postorder:
+                    // ## op    args         grad
+                    // 00 param src1         grad[#00]
+                    // 01 const 1
+                    // 02 sqr   (#00)        grad[#02]
+                    // 03 sum   (#02)        grad[#03]
+                    // 04 const 1/N
+                    // 05 scale (#03, #04)   grad[#05]
+                    // 06 const eps
+                    // 07 add   (#05, #06)   grad[#07]
+                    // 08 sqrt  (#07)        grad[#08]
+                    // 09 div   (#01,#08)    grad[#09]
+                    // 10 scale (#00,#09)    grad[#10]
+                    //
+                    // backward pass, given grad[#10]
+                    // #10: scale
+                    // grad[#00] += scale(grad[#10],#09)
+                    // grad[#09] += sum(mul(grad[#10],#00))
+                    // #09: div
+                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
+                    // #08: sqrt
+                    // grad[#07] += mul(grad[#08], div(0.5, #08))
+                    // #07: add
+                    // grad[#05] += grad[#07]
+                    // #05: scale
+                    // grad[#03] += scale(grad[#05],#04)
+                    // #03: sum
+                    // grad[#02] += repeat(grad[#03], #02)
+                    // #02:
+                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
+                    //
+                    // substitute and simplify:
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
+                    // grad[#02] = repeat(grad[#03], #02)
+                    // grad[#02] = repeat(scale(grad[#05],#04), #02)
+                    // grad[#02] = repeat(scale(grad[#07],#04), #02)
+                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
+                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
+                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
+                    // a = b*c + d*e
+                    // a = b*c*f/f + d*e*f/f
+                    // a = (b*c*f + d*e*f)*(1/f)
+                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
+                    // a = (b + d*e/c)*c
+                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
+                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
+                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
+                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
+                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
+                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
+                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
+                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
+                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                }
+                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                // post-order:
+                // dx := x
+                // dx := scale(dx,-mean_xdz/mean_eps)
+                // dx := add(dx, dz)
+                // dx := scale(dx, rrms)
+                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
+                ggml_vec_cpy_f32  (ne00, dx, x);
+                // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
+                ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
+                ggml_vec_acc_f32  (ne00, dx, dz);
+                ggml_vec_scale_f32(ne00, dx, rrms);
+            }
+        }
+    }
+}
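+
+// compact per-element form of the loop above, with N = ne00:
+// dx[i] = rrms * (dz[i] - x[i] * sum_xdz / (N * mean_eps)),
+// which matches the cpy/scale/acc/scale sequence since sum_eps = N * mean_eps.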
+
+void ggml_compute_forward_rms_norm_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_group_norm
+
+static void ggml_compute_forward_group_norm_f32(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // TODO: optimize
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
+    int n_channels = src0->ne[2];
+    int n_groups = dst->op_params[0];
+    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
+    for (int i = ith; i < n_groups; i += nth) {
+        int start = i * n_channels_per_group;
+        int end = start + n_channels_per_group;
+        if (end > n_channels) {
+            end = n_channels;
+        }
+        int step = end - start;
+
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            ggml_float sum = 0.0;
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    ggml_float sumr = 0.0;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        sumr += (ggml_float)x[i00];
+                    }
+                    sum += sumr;
+                }
+            }
+            const float mean = sum / (ne00 * ne01 * step);
+
+            ggml_float sum2 = 0.0;
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+
+                    ggml_float sumr = 0.0;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        float v = x[i00] - mean;
+                        y[i00] = v;
+                        sumr += (ggml_float)(v * v);
+                    }
+                    sum2 += sumr;
+                }
+            }
+            const float variance = sum2 / (ne00 * ne01 * step);
+            const float scale = 1.0f / sqrtf(variance + eps);
+
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+                    ggml_vec_scale_f32(ne00, y, scale);
+                }
+            }
+        }
+    }
+}
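+
+// channel grouping example: n_channels = 7 and n_groups = 3 give
+// n_channels_per_group = (7+3-1)/3 = 3, so the groups cover channels
+// [0,3), [3,6) and [6,7); each group is normalized over ne00*ne01*step values.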
+
+void ggml_compute_forward_group_norm(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_group_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_l2_norm
+
+static void ggml_compute_forward_l2_norm_f32(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps >= 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+
+                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
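+
+// per row: y = x / max(||x||_2, eps), i.e. the row is scaled to unit L2 norm
+// unless its norm falls below eps.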
+
+void ggml_compute_forward_l2_norm(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_l2_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_out_prod
+
+static void ggml_compute_forward_out_prod_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    GGML_ASSERT(ne2 % ne02 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    if (ith == 0) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
+    }
+    ggml_barrier(params->threadpool);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // block-tiling attempt
+    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+    const int64_t blck_1 = 16;
+
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
+    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+        const int64_t bir1 = MIN(bir + blck_1, ir1);
+        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+            for (int64_t ir = bir; ir < bir1; ++ir) {
+                // dst indices
+                const int64_t i3 = ir/(ne2*ne1);
+                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                const int64_t i02 = i2 / dps2;
+                const int64_t i03 = i3 / dps3;
+
+                //const int64_t i10 = i1;
+                const int64_t i12 = i2;
+                const int64_t i13 = i3;
+
+#if GGML_VEC_MAD_UNROLL > 2
+                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
+                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
+
+                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+                }
+                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#else
+                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#endif
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_out_prod_q_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 dim0
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+
+    // dst dim0 cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    if (ith == 0) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
+    }
+    ggml_barrier(params->threadpool);
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
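+    // per-thread scratch: each thread dequantizes one src0 row at a time into its own
+    // ne0-float slice of wdata; the CACHE_LINE_SIZE_F32 padding keeps neighbouring
+    // threads' slices on separate cache lines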
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // dst indices
+        const int64_t i3 = ir/(ne2*ne1);
+        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int64_t i02 = i2;
+        const int64_t i03 = i3;
+
+        //const int64_t i10 = i1;
+        const int64_t i12 = i2;
+        const int64_t i13 = i3;
+
+        for (int64_t i01 = 0; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+            dequantize_row_q(s0, wdata, ne0);
+            ggml_vec_mad_f32(ne0, d, wdata, *s1);
+        }
+    }
+}
+
+void ggml_compute_forward_out_prod(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+            {
+                ggml_compute_forward_out_prod_q_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ABORT("fatal error"); // todo
+                // ggml_compute_forward_out_prod_f16_f32(params, dst);
+            }
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_out_prod_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_scale
+
+static void ggml_compute_forward_scale_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    // scale factor
+    float v;
+    memcpy(&v, dst->op_params, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb1 = dst->nb[1];
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        if (dst->data != src0->data) {
+            // src0 is same shape as dst => same indices
+            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+        }
+        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
+    }
+}
+
+void ggml_compute_forward_scale(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_scale_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_set
+
+static void ggml_compute_forward_set_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset in bytes during set
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
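+    // i.e. the op writes src1 into a view of the (contiguous) dst that starts `offset`
+    // bytes into dst->data and uses the nb1/nb2/nb3 strides packed into op_params;
+    // when `inplace` is false, dst is first seeded with a full copy of src0 below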
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during set
+    const size_t nb0 = ggml_element_size(src0);
+
+    const int im0 = (ne10 == 0 ? 0 : ne10-1);
+    const int im1 = (ne11 == 0 ? 0 : ne11-1);
+    const int im2 = (ne12 == 0 ? 0 : ne12-1);
+    const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+        ggml_vec_cpy_f32(nc,
+                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+    }
+}
+
+static void ggml_compute_forward_set_i32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset in bytes during set
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during set
+    const size_t nb0 = ggml_element_size(src0);
+
+    const int im0 = (ne10 == 0 ? 0 : ne10-1);
+    const int im1 = (ne11 == 0 ? 0 : ne11-1);
+    const int im2 = (ne12 == 0 ? 0 : ne12-1);
+    const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
+
+    GGML_ASSERT(nb10 == sizeof(int32_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+        ggml_vec_cpy_i32(nc,
+                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+    }
+}
+
+void ggml_compute_forward_set(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_f32(params, dst);
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_set_i32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cpy
+
+void ggml_compute_forward_cpy(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, dst);
+}
+
+// ggml_compute_forward_cont
+
+void ggml_compute_forward_cont(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, dst);
+}
+
+// ggml_compute_forward_reshape
+
+void ggml_compute_forward_reshape(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    // NOP
+    GGML_UNUSED(params);
+    GGML_UNUSED(dst);
+}
+
+// ggml_compute_forward_view
+
+void ggml_compute_forward_view(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    // NOP
+    GGML_UNUSED(params);
+    GGML_UNUSED(dst);
+}
+
+// ggml_compute_forward_permute
+
+void ggml_compute_forward_permute(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    // NOP
+    GGML_UNUSED(params);
+    GGML_UNUSED(dst);
+}
+
+// ggml_compute_forward_transpose
+
+void ggml_compute_forward_transpose(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    // NOP
+    GGML_UNUSED(params);
+    GGML_UNUSED(dst);
+}
+
+// ggml_compute_forward_get_rows
+
+static void ggml_compute_forward_get_rows_q(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    const ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        dequantize_row_q(
+                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_f16(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        ggml_fp16_to_fp32_row(
+            (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                       (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_bf16(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_bf16_t));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        ggml_bf16_to_fp32_row(
+            (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        ggml_vec_cpy_f32(nc,
+                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
+                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+    }
+}
+
+void ggml_compute_forward_get_rows(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+            {
+                ggml_compute_forward_get_rows_q(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rows_f16(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_get_rows_bf16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_get_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
+}
+
+// ggml_compute_forward_get_rows_back
+
+static void ggml_compute_forward_get_rows_back_f32_f16(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    memset(dst->data, 0, ggml_nbytes(dst));
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    GGML_ASSERT( dst->ne[0] == nc);
+    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        for (int j = 0; j < nc; ++j) {
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
+            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
+        }
+    }
+}
+
+static void ggml_compute_forward_get_rows_back_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    memset(dst->data, 0, ggml_nbytes(dst));
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    GGML_ASSERT( dst->ne[0] == nc);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        ggml_vec_add_f32(nc,
+                (float *) ((char *)  dst->data + r*dst->nb[1]),
+                (float *) ((char *)  dst->data + r*dst->nb[1]),
+                (float *) ((char *) src0->data + i*src0->nb[1]));
+    }
+}
+
+void ggml_compute_forward_get_rows_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rows_back_f32_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_get_rows_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
+}
+
+// ggml_compute_forward_diag
+
+static void ggml_compute_forward_diag_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(ne00 == ne0);
+    GGML_ASSERT(ne00 == ne1);
+    GGML_ASSERT(ne01 == 1);
+    GGML_ASSERT(ne02 == ne2);
+    GGML_ASSERT(ne03 == ne3);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb0  == sizeof(float));
+
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = 0; i2 < ne2; i2++) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);
+                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
+                for (int i0 = 0; i0 < i1; i0++) {
+                    d[i0] = 0;
+                }
+                d[i1] = s[i1];
+                for (int i0 = i1+1; i0 < ne0; i0++) {
+                    d[i0] = 0;
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_diag(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_diag_mask_inf
+
+static void ggml_compute_forward_diag_mask_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        const float value) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int  n_past  = ((int32_t *) dst->op_params)[0];
+    const bool inplace = src0->data == dst->data;
+
+    GGML_ASSERT(n_past >= 0);
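+    // the loop below leaves columns i <= n_past + j untouched and writes `value` above
+    // that shifted diagonal, so row j keeps its first n_past + j + 1 entries
+    // (causal masking when value == -INFINITY)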
+
+    if (!inplace) {
+        if (ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+    const int nr = src0->ne[1];
+    const int nz = n/nr;
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int k = 0; k < nz; k++) {
+        for (int j = ith; j < nr; j += nth) {
+            for (int i = n_past; i < nc; i++) {
+                if (i > n_past + j) {
+                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_diag_mask_inf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+void ggml_compute_forward_diag_mask_zero(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_mask_f32(params, dst, 0);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_soft_max
+
+static void ggml_compute_forward_soft_max_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
+    // TODO: is this supposed to be ceil instead of floor?
+    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
+    const uint32_t n_head      = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
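+    // ALiBi slope bases: heads below n_head_log2 use slope m0^(h+1), the remaining
+    // heads use m1^(2*(h - n_head_log2) + 1) (see the per-row `slope` below);
+    // with max_bias == 0 every slope collapses to 1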
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        // ALiBi
+        const uint32_t h = (i1/ne01)%ne02; // head
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+
+        ggml_vec_cpy_f32  (nc, wp, sp);
+        ggml_vec_scale_f32(nc, wp, scale);
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*mp_f32[i];
+                }
+            }
+        }
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(wp[i]));
+        }
+#endif
+
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, wp);
+
+        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
+        assert(sum > 0.0);
+
+        sum = 1.0/sum;
+        ggml_vec_scale_f32(nc, dp, sum);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dp[i]));
+            assert(!isinf(dp[i]));
+        }
+#endif
+    }
+}
+
+void ggml_compute_forward_soft_max(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
+// ggml_compute_forward_soft_max_ext_back
+
+static void ggml_compute_forward_soft_max_ext_back_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+        // dxk = -yk * dot(y, dy) + yk*dyk
+        // dxk = yk * (- dot(y, dy) + dyk)
+        // dxk = yk * (dyk - dot(y, dy))
+        //
+        // post-order:
+        // dot_y_dy := dot(y, dy)
+        // dx := dy
+        // dx := dx - dot_y_dy
+        // dx := dx * y
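+        //
+        // quick sanity check with nc = 2, y = (0.5, 0.5), dy = (1, 0):
+        // dot(y, dy) = 0.5, so dx = (0.5*(1 - 0.5), 0.5*(0 - 0.5)) = (0.25, -0.25),
+        // which sums to zero as expected for a softmax gradient (before the final
+        // multiplication by `scale`)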
+
+        // linear runtime, no additional memory
+        float dot_y_dy = 0;
+        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        ggml_vec_cpy_f32  (nc, dx, dy);
+        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32  (nc, dx, dx, y);
+        ggml_vec_scale_f32(nc, dx, scale);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+void ggml_compute_forward_soft_max_ext_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_clamp
+
+static void ggml_compute_forward_clamp_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    float min;
+    float max;
+    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    for (int j = ith; j < n; j += nth) {
+        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
+        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+
+        for (int i = 0; i < nc; i++) {
+            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
+        }
+    }
+}
+
+static void ggml_compute_forward_clamp_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    float min;
+    float max;
+    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *)  dst->data + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+
+        for (int i = 0; i < nc; i++) {
+            float v = GGML_FP16_TO_FP32(src0_ptr[i]);
+            dst_ptr[i] = GGML_FP32_TO_FP16(MAX(MIN(v, max), min));
+        }
+    }
+}
+
+void ggml_compute_forward_clamp(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_clamp_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_clamp_f16(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_Q8_K:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_I64:
+        case GGML_TYPE_F64:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rope
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+    return 1 - MIN(1, MAX(0, y));
+}
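+
+// rope_yarn_ramp returns 1 for rotary dims at or below corr_dims[0], 0 for dims at or
+// above corr_dims[1], and a linear blend in between; rope_yarn uses it to mix the
+// interpolated and extrapolated angles per dimension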
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+static void ggml_rope_cache_init(
+     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+     float * cache, float sin_sign, float theta_scale) {
+    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+    float theta = theta_base;
+    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+        rope_yarn(
+            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+        );
+        cache[i0 + 1] *= sin_sign;
+
+        theta *= theta_scale;
+    }
+}
+
+static void ggml_mrope_cache_init(
+     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+     float * cache, float sin_sign, float theta_scale) {
+    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+    float theta_t = theta_base_t;
+    float theta_h = theta_base_h;
+    float theta_w = theta_base_w;
+    float theta_e = theta_base_e;  // extra position id for vision encoder
+    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+    int sec_w = sections[1] + sections[0];
+    int sec_e = sections[2] + sec_w;
+    GGML_ASSERT(sect_dims <= ne0);
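+    // sections[] splits the rotary dims into temporal / height / width / extra blocks;
+    // sec_w and sec_e are the cumulative boundaries used below to select which of the
+    // four thetas applies to a given pair of dims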
+
+    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+
+        int sector = (i0 / 2) % sect_dims;
+        if (indep_sects) {
+            // compute theta independently for each dim section
+            // (i.e. reset the corresponding theta when `i0` crosses from one section into another)
+            if (sector == 0) {
+                theta_t = theta_base_t;
+            }
+            else if (sector == sections[0]) {
+                theta_h = theta_base_h;
+            }
+            else if (sector == sec_w) {
+                theta_w = theta_base_w;
+            }
+            else if (sector == sec_e) {
+                theta_e = theta_base_e;
+            }
+        }
+
+        float theta = theta_t;
+        if (sector >= sections[0] && sector < sec_w) {
+            theta = theta_h;
+        }
+        else if (sector >= sec_w && sector < sec_w + sections[2]) {
+            theta = theta_w;
+        }
+        else if (sector >= sec_w + sections[2]) {
+            theta = theta_e;
+        }
+
+        rope_yarn(
+            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+        );
+        cache[i0 + 1] *= sin_sign;
+
+        theta_t *= theta_scale;
+        theta_w *= theta_scale;
+        theta_h *= theta_scale;
+        theta_e *= theta_scale;
+    }
+}
+
+static void ggml_compute_forward_rope_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        const bool forward) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    int sections[4];
+
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    GGML_ASSERT(n_dims <= ne0);
+    GGML_ASSERT(n_dims % 2 == 0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
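+    // standard RoPE frequency schedule: pair i0/2 rotates by theta_base * theta_scale^(i0/2)
+    // = p * freq_base^(-i0/n_dims), before the YaRN correction applied in rope_yarn()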
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, multimodal rotary position embedding
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_mrope) {
+        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }
+
+    const float * freq_factors = NULL;
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
+    }
+
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
+    const int32_t * pos = (const int32_t *) src1->data;
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
+
+            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+            if (!is_mrope) {
+                const int64_t p = pos[i2];
+                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+            else {
+                const int64_t p_t = pos[i2];
+                const int64_t p_h = pos[i2 + ne2];
+                const int64_t p_w = pos[i2 + ne2 * 2];
+                const int64_t p_e = pos[i2 + ne2 * 3];
+                ggml_mrope_cache_init(
+                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+
+            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
+                if (is_neox || is_mrope) {
+                    if (is_vision){
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = src[0];
+                            const float x1 = src[n_dims];
+
+                            dst_data[0]      = x0*cos_theta - x1*sin_theta;
+                            dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
+                        }
+                    } else {
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = src[0];
+                            const float x1 = src[n_dims/2];
+
+                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                        }
+                    }
+                } else {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[1];
+
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[1] = x0*sin_theta + x1*cos_theta;
+                    }
+                }
+
+                if (is_vision) {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims];
+
+                        dst_data[0]      = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
+                    }
+                } else {
+                    // fill the remaining channels with data from the src tensor
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
+                    }
+                }
+            }
+        }
+    }
+}
+
+// TODO: deduplicate f16/f32 code
+static void ggml_compute_forward_rope_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        const bool forward) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    int sections[4];
+
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
+
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    GGML_ASSERT(n_dims <= ne0);
+    GGML_ASSERT(n_dims % 2 == 0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_mrope) {
+        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }
+
+    const float * freq_factors = NULL;
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
+    }
+
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
+    const int32_t * pos = (const int32_t *) src1->data;
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+
+            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+            if (!is_mrope) {
+                const int64_t p = pos[i2];
+                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+            else {
+                const int64_t p_t = pos[i2];
+                const int64_t p_h = pos[i2 + ne2];
+                const int64_t p_w = pos[i2 + ne2 * 2];
+                const int64_t p_e = pos[i2 + ne2 * 3];
+                ggml_mrope_cache_init(
+                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
+                if (is_neox || is_mrope) {
+                    if (is_vision) {
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = GGML_FP16_TO_FP32(src[0]);
+                            const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+
+                            dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        }
+                    } else {
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = GGML_FP16_TO_FP32(src[0]);
+                            const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                            dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        }
+                    }
+                } else {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[1]);
+
+                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
+                }
+
+                if (is_vision) {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+
+                        dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
+                } else {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
+                    }
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_rope(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_f16(params, dst, true);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rope_f32(params, dst, true);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rope_back
+
+void ggml_compute_forward_rope_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_f16(params, dst, false);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rope_f32(params, dst, false);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_conv_transpose_1d
+
+static void ggml_compute_forward_conv_transpose_1d_f16_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00*ne01*ne02;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (ith == 0) {
+        memset(params->wdata, 0, params->wsize);
+
+        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
+
+        // permute source data (src1) from (L x Cin) to (Cin x L)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            ggml_fp16_t * dst_data = wdata;
+
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
+                }
+            }
+        }
+
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
+    ggml_barrier(params->threadpool);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
+    // total rows in dst
+    const int nr = ne1;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = wdata + nk;
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f16(ne02, &v, 0,
+                        (ggml_fp16_t *)    wdata_src + i1n, 0,
+                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
+                dst_data[i10*s0 + i00] += v;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_transpose_1d_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00*ne01*ne02;
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (ith == 0) {
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            float * const wdata = (float *) params->wdata + 0;
+
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    float * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
+
+        // prepare source data (src1)
+        {
+            float * const wdata = (float *) params->wdata + nk;
+            float * dst_data = wdata;
+
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = src[i10];
+                }
+            }
+        }
+
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
+    ggml_barrier(params->threadpool);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
+    // total rows in dst
+    const int nr = ne1;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * const wdata     = (float *) params->wdata + 0;
+    float * const wdata_src = wdata + nk;
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        float * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f32(ne02, &v, 0,
+                        wdata_src + i1n, 0,
+                        wdata_kernel + i00*ne02, 0, 1);
+                dst_data[i10*s0 + i00] += v;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_transpose_1d(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_transpose_1d_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_im2col_f32
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
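+// each dst row holds the IC*KH*KW input patch read by a convolution at output position (oh, ow),
+// with out-of-bounds (padded) positions written as 0, so the convolution itself reduces to a
+// matrix multiplication of dst against the flattened kernel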
+static void ggml_compute_forward_im2col_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                        // micro kernel
+                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+
+// ggml_compute_forward_im2col_f16
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f16(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                        // micro kernel
+                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_im2col(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_im2col_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_im2col_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_im2col_back_f32
+
+void ggml_compute_forward_im2col_back_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const ggml_tensor * src1 = dst->src[1]; // convolution kernel
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne3 : ne2;
+    const int64_t IC = is_2D ? ne2 : ne1;
+    const int64_t IH = is_2D ? ne1 : 1;
+    const int64_t IW = ne0;
+
+    const int64_t KH = is_2D ? ne11 : 1;
+    const int64_t KW = ne10;
+
+    const int64_t OH = is_2D ? ne02 : 1;
+    const int64_t OW = ne01;
+
+    int ofs0 = is_2D ? nb3 : nb2;
+    int ofs1 = is_2D ? nb2 : nb1;
+
+    GGML_ASSERT(nb0  == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iic = ith; iic < IC; iic += nth) {
+                for (int64_t iih = 0; iih < IH; iih++) {
+                    for (int64_t iiw = 0; iiw < IW; iiw++) {
+
+                        // micro kernel
+                        float grad = 0.0f;
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                // For s0 > 1 some values were skipped over in the forward pass.
+                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
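+                                // e.g. with s0 = 2, d0 = 1, p0 = 0 the forward pass only read iiw = iow*2 + ikw,
+                                // so an odd tmpw has no matching output column and contributes no gradient.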
+                                const int64_t tmpw = (iiw + p0 - ikw*d0);
+                                if (tmpw % s0 != 0) {
+                                    continue;
+                                }
+                                const int64_t iow = tmpw / s0;
+
+                                // Equivalent logic as above except for s1.
+                                int64_t ioh;
+                                if (is_2D) {
+                                    const int64_t tmph = iih + p1 - ikh*d1;
+
+                                    if (tmph % s1 != 0) {
+                                        continue;
+                                    }
+
+                                    ioh = tmph / s1;
+                                } else {
+                                    ioh = 0;
+                                }
+
+                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
+                                    continue;
+                                }
+
+                                const float * const grad_in = (const float *) src0->data
+                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
+                            }
+                        }
+                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
+                        dst_data[iih*IW + iiw] = grad;
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_conv_transpose_2d
+
+void ggml_compute_forward_conv_transpose_2d(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00*ne01*ne02*ne03;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (ith == 0) {
+        memset(params->wdata, 0, params->wsize);
+
+        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
+                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
+                    for (int64_t i01 = 0; i01 < ne01; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
+                        }
+                    }
+                }
+            }
+        }
+
+        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            for (int i12 = 0; i12 < ne12; i12++) {
+                for (int i11 = 0; i11 < ne11; i11++) {
+                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
+                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
+                    for (int i10 = 0; i10 < ne10; i10++) {
+                        dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
+                    }
+                }
+            }
+        }
+
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
+    ggml_barrier(params->threadpool);
+
+    const int32_t stride = ggml_get_op_params_i32(dst, 0);
+
+    // total patches in dst
+    const int np = ne2;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = wdata + nk;
+
+    for (int i2 = ip0; i2 < ip1; i2++) { // Cout
+        float * dst_data = (float *)((char *) dst->data + i2*nb2);
+        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
+        for (int i11 = 0; i11 < ne11; i11++) {
+            for (int i10 = 0; i10 < ne10; i10++) {
+                const int i1n = i11*ne10*ne12 + i10*ne12;
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        float v = 0;
+                        ggml_vec_dot_f16(ne03, &v, 0,
+                                wdata_src + i1n, 0,
+                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_pool_1d_sk_p0
+
+static void ggml_compute_forward_pool_1d_sk_p0(
+        const ggml_compute_params * params,
+        const ggml_op_pool op,
+        const int k,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src = dst->src[0];
+
+    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    const char * cdata = (const char *)src->data;
+    const char * const data_end = cdata + ggml_nbytes(src);
+    float * drow = (float *)dst->data;
+
+    const int64_t rs = dst->ne[0];
+
+    while (cdata < data_end) {
+        const void * srow = (const void *)cdata;
+        int j = 0;
+        for (int64_t i = 0; i < rs; ++i) {
+            switch (op) {
+                case GGML_OP_POOL_AVG:   drow[i] = 0;        break;
+                case GGML_OP_POOL_MAX:   drow[i] = -FLT_MAX; break;
+                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+            }
+            for (int ki = 0; ki < k; ++ki) {
+                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
+                switch (op) {
+                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
+                    case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break;
+                    case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error");
+                }
+                ++j;
+            }
+            switch (op) {
+                case GGML_OP_POOL_AVG:         drow[i] /= k; break;
+                case GGML_OP_POOL_MAX:                       break;
+                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+            }
+        }
+
+        cdata += src->nb[1];
+        drow  += rs;
+    }
+}
+
+// ggml_compute_forward_pool_1d
+
+void ggml_compute_forward_pool_1d(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int s0 = opts[2];
+    const int p0 = opts[3];
+    GGML_ASSERT(p0 == 0); // padding not supported
+    GGML_ASSERT(k0 == s0); // only s = k supported
+
+    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
+}
+
+// ggml_compute_forward_pool_2d
+
+void ggml_compute_forward_pool_2d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src = dst->src[0];
+
+    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+    const char * cdata = (const char*)src->data;
+    const char * const data_end = cdata + ggml_nbytes(src);
+
+    const int64_t px = dst->ne[0];
+    const int64_t py = dst->ne[1];
+    const int64_t pa = px * py;
+
+    float * dplane = (float *)dst->data;
+
+    const int ka = k0 * k1;
+    const int offset0 = -p0;
+    const int offset1 = -p1;
+
+    while (cdata < data_end) {
+        for (int oy = 0; oy < py; ++oy) {
+            float * const drow = dplane + oy * px;
+            for (int ox = 0; ox < px; ++ox) {
+                float * const out =  drow + ox;
+                switch (op) {
+                    case GGML_OP_POOL_AVG:     *out = 0;        break;
+                    case GGML_OP_POOL_MAX:     *out = -FLT_MAX; break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                }
+
+                const int ix = offset0 + ox * s0;
+                const int iy = offset1 + oy * s1;
+
+                for (int ky = 0; ky < k1; ++ky) {
+                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
+                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
+                    for (int kx = 0; kx < k0; ++kx) {
+                        int j = ix + kx;
+                        if (j < 0 || j >= src->ne[0]) continue;
+                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
+                        switch (op) {
+                            case GGML_OP_POOL_AVG:                     *out += srow_j; break;
+                            case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break;
+                            case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
+                        }
+                    }
+                }
+                switch (op) {
+                    case GGML_OP_POOL_AVG:           *out /= ka; break;
+                    case GGML_OP_POOL_MAX:                       break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                }
+            }
+        }
+
+        cdata  += src->nb[2];
+        dplane += pa;
+    }
+}
+
+// ggml_compute_forward_pool_2d_back
+
+void ggml_compute_forward_pool_2d_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src  = dst->src[0];
+    const ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
+
+    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    char       * cdata  = (char       *) dst->data;
+    const char * cdataf = (const char *) dstf->data;
+    const char * const data_end = cdata + ggml_nbytes(dst);
+
+    GGML_ASSERT(params->ith == 0);
+    memset(cdata, 0, ggml_nbytes(dst));
+
+    const int64_t px = src->ne[0];
+    const int64_t py = src->ne[1];
+    const int64_t pa = px * py;
+
+    const float * splane = (const float *) src->data;
+
+    const int ka = k0 * k1;
+    const int offset0 = -p0;
+    const int offset1 = -p1;
+
+    while (cdata < data_end) {
+        for (int oy = 0; oy < py; ++oy) {
+            const float * const srow = splane + oy * px;
+            for (int ox = 0; ox < px; ++ox) {
+                const float grad0 = srow[ox];
+
+                const int ix = offset0 + ox * s0;
+                const int iy = offset1 + oy * s1;
+
+                if (op == GGML_OP_POOL_MAX) {
+                    float maxval = -FLT_MAX;
+                    int kxmax = -1;
+                    int kymax = -1;
+
+                    for (int ky = 0; ky < k1; ++ky) {
+                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
+                            continue;
+                        }
+                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
+                        for (int kx = 0; kx < k0; ++kx) {
+                            int j = ix + kx;
+                            if (j < 0 || j >= dst->ne[0]) {
+                                continue;
+                            }
+
+                            const float val = dst->type == GGML_TYPE_F32 ?
+                                ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
+                            if (val <= maxval) {
+                                continue;
+                            }
+
+                            maxval = val;
+                            kxmax = kx;
+                            kymax = ky;
+                        }
+                    }
+
+                    if (kxmax == -1 || kymax == -1) {
+                        continue;
+                    }
+
+                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
+                    const int j = ix + kxmax;
+                    if (dst->type == GGML_TYPE_F32) {
+                        ((float *) drow)[j] += grad0;
+                    } else {
+                        ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
+                    }
+                } else if (op == GGML_OP_POOL_AVG) {
+                    const float grad = grad0 / ka;
+
+                    for (int ky = 0; ky < k1; ++ky) {
+                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
+                            continue;
+                        }
+                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
+                        for (int kx = 0; kx < k0; ++kx) {
+                            int j = ix + kx;
+                            if (j < 0 || j >= dst->ne[0]) {
+                                continue;
+                            }
+
+                            if (dst->type == GGML_TYPE_F32) {
+                                ((float *) drow)[j] += grad;
+                            } else {
+                                ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad);
+                            }
+                        }
+                    }
+                } else {
+                    GGML_ASSERT(false);
+                }
+            }
+        }
+
+        cdata  += dst->nb[2];
+        cdataf += dst->nb[2];
+        splane += pa;
+    }
+}
+
+// ggml_compute_forward_upscale
+
+static void ggml_compute_forward_upscale_f32(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const float sf0 = (float)ne0/src0->ne[0];
+    const float sf1 = (float)ne1/src0->ne[1];
+    const float sf2 = (float)ne2/src0->ne[2];
+    const float sf3 = (float)ne3/src0->ne[3];
+
+    const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
+
+    if (mode == GGML_SCALE_MODE_NEAREST) {
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            const int64_t i03 = i3 / sf3;
+            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+                const int64_t i02 = i2 / sf2;
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const int64_t i01 = i1 / sf1;
+                    for (int64_t i0 = 0; i0 < ne0; i0++) {
+                        const int64_t i00 = i0 / sf0;
+
+                        const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
+
+                        *y = *x;
+                    }
+                }
+            }
+        }
+    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
+        // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
+        const float pixel_offset = 0.5f;
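+        // dst pixel i samples the source at ((i + 0.5)/sf - 0.5), i.e. the half-pixel (align_corners=False)
+        // convention; the result is a bilinear blend of the four clamped neighbouring source pixels.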
+
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            const int64_t i03 = i3 / sf3;
+            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+                const int64_t i02 = i2 / sf2;
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
+                    int64_t y0 = (int64_t)floorf(y);
+                    int64_t y1 = y0 + 1;
+
+                    y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
+                    y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
+
+                    float dy = y - (float)y0;
+                    dy = std::max(0.0f, std::min(dy, 1.0f));
+
+                    for (int64_t i0 = 0; i0 < ne0; i0++) {
+                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
+                        int64_t x0 = (int64_t)floorf(x);
+                        int64_t x1 = x0 + 1;
+
+                        x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
+                        x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
+
+                        float dx = x - (float)x0;
+                        dx = std::max(0.0f, std::min(dx, 1.0f));
+
+                        // fetch the four surrounding pixel values and interpolate
+                        const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
+                        const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
+                        const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
+                        const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
+
+                        const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
+
+                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+                        *y_dst = val;
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("unsupported upscale mode");
+    }
+}
+
+void ggml_compute_forward_upscale(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_upscale_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
+// ggml_compute_forward_pad
+
+static void ggml_compute_forward_pad_f32(
+    const ggml_compute_params * params,
+          ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    } else {
+                        dst_ptr[dst_idx] = 0;
+                    }
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_pad(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_pad_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_pad_reflect_1d
+
+void ggml_compute_forward_pad_reflect_1d(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int32_t * opts = (const int32_t *) dst->op_params;
+    const int p0 = opts[0];
+    const int p1 = opts[1];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+                float * left  = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 +         p0*nb0);
+                float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
+
+                ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
+
+                for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   }
+                for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_unpad
+
+static void ggml_compute_forward_unpad_f32(
+    const struct ggml_compute_params *params,
+    struct ggml_tensor *dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    }
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_unpad(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_unpad_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_arange
+
+static void ggml_compute_forward_arange_f32(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const float start = ggml_get_op_params_f32(dst, 0);
+    const float stop  = ggml_get_op_params_f32(dst, 1);
+    const float step  = ggml_get_op_params_f32(dst, 2);
+
+    const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+    GGML_ASSERT(ggml_nelements(dst) == steps);
+
+    for (int64_t i = ith; i < steps; i+= nth) {
+        float value = start + step * i;
+        ((float *)dst->data)[i] = value;
+    }
+}
+
+void ggml_compute_forward_arange(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_arange_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_timestep_embedding_f32(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int dim = ggml_get_op_params_i32(dst, 0);
+    const int max_period = ggml_get_op_params_i32(dst, 1);
+
+    int half = dim / 2;
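+    // sinusoidal embedding: frequency freq_j = exp(-ln(max_period) * j / half) for j in [0, half),
+    // with cos(t*freq_j) stored in the first half of the row and sin(t*freq_j) in the second half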
+
+    for (int64_t i = 0; i < ne00; i++) {
+        float * embed_data = (float *)((char *)  dst->data +  i*nb1);
+        for (int64_t j = ith; j < half; j += nth) {
+            float timestep = ((float *)src0->data)[i];
+            float freq = (float)expf(-logf(max_period) * j / half);
+            float arg = timestep * freq;
+            embed_data[j] = cosf(arg);
+            embed_data[j + half] = sinf(arg);
+        }
+        if (dim % 2 != 0 && ith == 0) {
+            embed_data[dim] = 0.f;
+        }
+    }
+}
+
+void ggml_compute_forward_timestep_embedding(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_timestep_embedding_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_argsort
+
+static void ggml_compute_forward_argsort_f32(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne0; j++) {
+            dst_data[j] = j;
+        }
+
+        // C doesn't have a functional sort, so we do a bubble sort instead
+        for (int64_t j = 0; j < ne0; j++) {
+            for (int64_t k = j + 1; k < ne0; k++) {
+                if ((order == GGML_SORT_ORDER_ASC  && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+                    (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+                    int32_t tmp = dst_data[j];
+                    dst_data[j] = dst_data[k];
+                    dst_data[k] = tmp;
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_argsort(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_argsort_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_flash_attn_ext
+
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const ggml_compute_params * params,
+        const ggml_tensor * q,
+        const ggml_tensor * k,
+        const ggml_tensor * v,
+        const ggml_tensor * mask,
+        ggml_tensor * dst) {
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
+
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head      = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
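+    // per-head bias slopes (ALiBi-style): head h uses m0^(h+1) for h < n_head_log2,
+    // otherwise m1^(2*(h - n_head_log2) + 1); max_bias == 0 disables the bias (slope 1.0)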
+
+    ggml_type    const k_vec_dot_type      = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;
+    ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
+
+    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
+    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
+
+    // loop over n_batch and n_head
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        const uint32_t h = iq2; // head index
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
+
+        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
+
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, DV*sizeof(float));
+        }
+
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, DK);
+
+        // online softmax / attention
+        // loop over n_kv and n_head_kv
+        // ref: https://arxiv.org/pdf/2112.05682.pdf
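+        // running state per output row: M = running max score, S = sum of exp(score - M),
+        // VKQ = sum of exp(score - M) * V; when a new maximum arrives the old accumulators
+        // are rescaled by exp(Mold - M) so the normalization stays consistent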
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mv == -INFINITY) {
+                continue;
+            }
+
+            float s; // KQ value
+
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
+
+            s = s*scale; // scale KQ value
+
+            if (logit_softcap != 0.0f) {
+                s = logit_softcap*tanhf(s);
+            }
+
+            s += mv; // apply mask
+
+            const float Mold = M;
+
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            if (v->type == GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
+
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(DV, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
+            } else {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
+
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(DV, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                // V += v*expf(s - M)
+                if (v_to_float) {
+                    v_to_float(v_data, V32, DV);
+                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                } else {
+                    // V is F32
+                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+                }
+            }
+
+            S = S*ms + vs; // scale and increment sum with partial sum
+        }
+
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < DV; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
+        }
+
+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(DV, VKQ32, S_inv);
+
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // original
+        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+        // permute(0, 2, 1, 3)
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+    }
+}
+
+void ggml_compute_forward_flash_attn_ext(
+        const ggml_compute_params * params,
+        const ggml_tensor * q,
+        const ggml_tensor * k,
+        const ggml_tensor * v,
+        const ggml_tensor * mask,
+        ggml_tensor * dst) {
+    switch (dst->op_params[3]) {
+        case GGML_PREC_DEFAULT:
+        case GGML_PREC_F32:
+            {
+                // uses F32 accumulators
+                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_flash_attn_back
+
+static void ggml_compute_forward_flash_attn_back_f32(
+        const ggml_compute_params * params,
+        const bool masked,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * q = dst->src[0];
+    const ggml_tensor * k = dst->src[1];
+    const ggml_tensor * v = dst->src[2];
+    const ggml_tensor * d = dst->src[3];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (ith == 0) {
+        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+    }
+    ggml_barrier(params->threadpool);
+
+    const int64_t elem_q = ggml_nelements(q);
+    const int64_t elem_k = ggml_nelements(k);
+
+    ggml_type result_type = dst->type;
+    GGML_ASSERT(ggml_blck_size(result_type) == 1);
+    const size_t tsize = ggml_type_size(result_type);
+
+    const size_t offs_q = 0;
+    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+
+    void * grad_q = (char *) dst->data;
+    void * grad_k = (char *) dst->data + offs_k;
+    void * grad_v = (char *) dst->data + offs_v;
+
+    const size_t nbgq1 = nb0*neq0;
+    const size_t nbgq2 = nb0*neq0*neq1;
+    const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+    const size_t nbgk1 = nb0*nek0;
+    const size_t nbgk2 = nb0*nek0*nek1;
+    const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+    const size_t nbgv1 = nb0*nev0;
+    const size_t nbgv2 = nb0*nev0*nev1;
+    const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+    // parallelize by k rows using ggml_vec_dot_f32
+
+    // total rows in k
+    const int nr = nek2*nek3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    // how often k2 (and v2) is repeated in q2
+    int nrep = neq2/nek2;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // k indices
+        const int ik3 = ir/(nek2);
+        const int ik2 = ir - ik3*nek2;
+
+        const int iq3 = ik3;
+        const int id3 = ik3;
+        const int iv3 = ik3;
+        const int iv2 = ik2;
+
+        for (int irep = 0; irep < nrep; ++irep) {
+            const int iq2 = ik2 + irep*nek2;
+            const int id2 = iq2;
+
+            // (ik2 + irep*nek2) % nek2 == ik2
+            for (int iq1 = 0; iq1 < neq1; ++iq1) {
+                const int id1 = iq1;
+
+                // TODO: not sure about CACHE_LINE_SIZE_F32 here - maybe it should not be
+                // multiplied by 2, and should be excluded from the 1*(..) offset used for SM?
+                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+                for (int i = M; i < Mup; ++i) {
+                    S[i] = -INFINITY;
+                }
+
+                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    // k indices
+                    const int ik1 = ic;
+
+                    // S indices
+                    const int i1 = ik1;
+
+                    ggml_vec_dot_f32(neq0,
+                            S + i1, 0,
+                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
+                }
+
+                // scale
+                ggml_vec_scale_f32(masked_begin, S, scale);
+
+                for (int64_t i = masked_begin; i < M; i++) {
+                    S[i] = -INFINITY;
+                }
+
+                // softmax
+                // exclude known -INF S[..] values from max and loop
+                // don't forget to set their SM values to zero
+                {
+                    float max = -INFINITY;
+                    ggml_vec_max_f32(masked_begin, &max, S);
+
+                    ggml_float sum = 0.0;
+                    {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                        max = -max;
+                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                        vvexpf(SM, SM, &Mup);
+                        ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
+#endif
+                    }
+
+                    assert(sum > 0.0);
+
+                    sum = 1.0/sum;
+                    ggml_vec_scale_f32(masked_begin, SM, sum);
+
+                }
+
+                // step-by-step explanation
+                {
+                    // forward-process                    shape      grads from backward process
+                    // parallel_for ik2,ik3:
+                    //  for irep:
+                    //   iq2 = ik2 + irep*nek2
+                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]
+                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]
+                    //   for iq1:
+                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                    //    S0     = -Inf                   [D,1,1,1]
+                    //   ~S1[i]  = dot(kcur[:D,i], qcur)
+                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
+                    //    S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
+                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
+                    //   ~S5[i]  = dot(vcur[:,i], S4)
+                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]
+                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
+                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
+                    // dst                               backward-/ grad[dst]                 = d
+                    //
+                    // output gradients with their dependencies:
+                    //
+                    // grad[kcur] = grad[S1].T @ qcur
+                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    // grad[S4]   = grad[S5] @ vcur
+                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
+                    // grad[qcur] = grad[S1]   @ kcur
+                    // grad[vcur] = grad[S5].T @ S4
+                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+                    //
+                    // in post-order:
+                    //
+                    // S1         = qcur @ kcur.T
+                    // S2         = S1 * scale
+                    // S3         = diag_mask_inf(S2, P)
+                    // S4         = softmax(S3)
+                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
+                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                    // grad[qcur] = grad[S1]   @ kcur
+                    // grad[kcur] = grad[S1].T @ qcur
+                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+                    //
+                    // using less variables (SM=S4):
+                    //
+                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
+                    // SM            = softmax(S)
+                    // S             = d[:D,iq1,iq2,iq3] @ vcur
+                    // dot_SM_gradSM = dot(SM, S)
+                    // S             = SM * (S - dot(SM, S))
+                    // S             = diag_mask_zero(S, P) * scale
+                    //
+                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur
+                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM
+                }
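+
+                // worked example of the softmax gradient identity used below,
+                // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])):
+                //   S4       = softmax([0, 0]) = [0.5, 0.5]
+                //   grad[S4] = [1, 0]
+                //   dot(S4, grad[S4]) = 0.5
+                //   grad[S3] = [0.5*(1 - 0.5), 0.5*(0 - 0.5)] = [0.25, -0.25]
+                // (a sketch only - below the same computation is carried out in place on S and SM)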
+
+                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+                // for ic:
+                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
+                // exclude known future zero S[..] values from operation
+                ggml_vec_set_f32(masked_begin, S, 0);
+                for (int64_t ic = 0; ic < D; ++ic) {
+                    ggml_vec_mad_f32(masked_begin,
+                            S,
+                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),
+                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
+                }
+
+                // S = SM * (S - dot(SM, S))
+                float dot_SM_gradSM = 0;
+                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
+                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+                ggml_vec_mul_f32 (masked_begin, S, S, SM);
+
+                // S = diag_mask_zero(S, P) * scale
+                // already done by above ggml_vec_set_f32
+
+                // exclude known zero S[..] values from operation
+                ggml_vec_scale_f32(masked_begin, S, scale);
+
+                // S    shape [M,1]
+                // SM   shape [M,1]
+                // kcur shape [D,M]
+                // qcur shape [D,1]
+                // vcur shape [M,D]
+
+                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+                // for ic:
+                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
+                // exclude known zero S[..] values from loop
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    ggml_vec_mad_f32(D,
+                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),
+                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),
+                            S[ic]);
+                }
+
+                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
+                // for ic:
+                //  grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+                //  grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+                // exclude known zero S[..] values from loop
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    ggml_vec_mad_f32(D,
+                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),
+                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),
+                            S[ic]);
+                }
+
+                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM
+                // for ic:
+                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
+                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]
+                // exclude known zero SM[..] values from mad
+                for (int64_t ic = 0; ic < D; ++ic) {
+                    ggml_vec_mad_f32(masked_begin,
+                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
+                            SM,
+                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_flash_attn_back(
+        const ggml_compute_params * params,
+        const bool masked,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * q = dst->src[0];
+
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_ssm_conv
+
+static void ggml_compute_forward_ssm_conv_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // conv_x
+    const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+    const int nr  = src0->ne[1]; // d_inner
+    const int n_t =  dst->ne[1]; // tokens per sequence
+    const int n_s =  dst->ne[2]; // number of sequences in the batch
+
+    GGML_ASSERT( dst->ne[0] == nr);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+    const int ir  = ir1 - ir0;
+
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+
+            // TODO: transpose the output for smaller strides for big batches?
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // rowwise dot product
+                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+                float sumf = 0.0f;
+
+                // d_conv
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+                }
+                x[i1] = sumf;
+            }
+        }
+    }
+}
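+
+// Scalar form of the convolution above (a sketch, with i1 the absolute channel index):
+// each output element is a causal depthwise convolution over a window of d_conv inputs
+// ending at the current token:
+//
+//   x[i1, i2, i3] = sum_{i0 = 0..d_conv-1} conv_x[i2 + i0, i1, i3] * weight[i0, i1]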
+
+void ggml_compute_forward_ssm_conv(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    switch (dst->src[0]->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_conv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_ssm_scan
+
+static void ggml_compute_forward_ssm_scan_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // s
+    const ggml_tensor * src1 = dst->src[1]; // x
+    const ggml_tensor * src2 = dst->src[2]; // dt
+    const ggml_tensor * src3 = dst->src[3]; // A
+    const ggml_tensor * src4 = dst->src[4]; // B
+    const ggml_tensor * src5 = dst->src[5]; // C
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nc  = src0->ne[0]; // d_state
+    const int64_t nr  = src0->ne[1]; // d_inner
+    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
+    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+
+    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(float));
+    GGML_ASSERT(src4->nb[0] == sizeof(float));
+    GGML_ASSERT(src5->nb[0] == sizeof(float));
+    // required for the dot product between s and C
+    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+    // required for per-sequence offsets for states
+    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+    const int ir  = ir1 - ir0;
+
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+                  float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                  float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+}
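+
+// Per-element form of the selective scan above (a sketch, following the Mamba reference
+// linked in the loop; names are local to this comment):
+//
+//   dt_sp     = softplus(dt[i1])                 // log1p(exp(dt)), passed through above 20.0f
+//   dA        = exp(dt_sp * A[i0, i1])           // discretized state decay
+//   dBx       = B[i0] * (x[i1] * dt_sp)          // discretized input contribution
+//   s[i0, i1] = s_prev[i0, i1] * dA + dBx        // state update
+//   y[i1]    += s[i0, i1] * C[i0]                // output = rowwise dot(state, C)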
+
+void ggml_compute_forward_ssm_scan(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    switch (dst->src[0]->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_scan_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_win_part
+
+static void ggml_compute_forward_win_part_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    GGML_UNUSED(params);
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+
+    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
+
+    assert(ne00 == ne0);
+    assert(ne3  == nep0*nep1);
+
+    // TODO: optimize / multi-thread
+    for (int py = 0; py < nep1; ++py) {
+        for (int px = 0; px < nep0; ++px) {
+            const int64_t i3 = py*nep0 + px;
+            for (int64_t i2 = 0; i2 < ne2; ++i2) {
+                for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                        const int64_t i02 = py*w + i2;
+                        const int64_t i01 = px*w + i1;
+                        const int64_t i00 = i0;
+
+                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
+                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
+
+                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+                            ((float *) dst->data)[i] = 0.0f;
+                        } else {
+                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
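+
+// Shape sketch of the partitioning above (assuming dst was created with ne1 = ne2 = w):
+// src0 is [C, W, H] (ne00 = C, ne01 = W, ne02 = H) and dst is [C, w, w, nep0*nep1];
+// the W and H axes are split into nep0 x nep1 windows of size w x w, window (px, py)
+// landing at dst[..., py*nep0 + px], with positions past W or H zero-filled.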
+
+void ggml_compute_forward_win_part(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_part_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_win_unpart
+
+static void ggml_compute_forward_win_unpart_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    GGML_UNUSED(params);
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+
+    const int32_t w = ((const int32_t *)(dst->op_params))[0];
+
+    // padding
+    const int px = (w - ne1%w)%w;
+    //const int py = (w - ne2%w)%w;
+
+    const int npx = (px + ne1)/w;
+    //const int npy = (py + ne2)/w;
+
+    assert(ne0 == ne00);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int ip2 = i2/w;
+                const int ip1 = i1/w;
+
+                const int64_t i02 = i2%w;
+                const int64_t i01 = i1%w;
+                const int64_t i00 = i0;
+
+                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
+                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;
+
+                ((float *) dst->data)[j] = ((float *) src0->data)[i];
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_win_unpart(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_unpart_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_unary
+
+void ggml_compute_forward_unary(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_unary_op op = ggml_get_unary_op(dst);
+
+    switch (op) {
+        case GGML_UNARY_OP_ABS:
+            {
+                ggml_compute_forward_abs(params, dst);
+            } break;
+        case GGML_UNARY_OP_SGN:
+            {
+                ggml_compute_forward_sgn(params, dst);
+            } break;
+        case GGML_UNARY_OP_NEG:
+            {
+                ggml_compute_forward_neg(params, dst);
+            } break;
+        case GGML_UNARY_OP_STEP:
+            {
+                ggml_compute_forward_step(params, dst);
+            } break;
+        case GGML_UNARY_OP_TANH:
+            {
+                ggml_compute_forward_tanh(params, dst);
+            } break;
+        case GGML_UNARY_OP_ELU:
+            {
+                ggml_compute_forward_elu(params, dst);
+            } break;
+        case GGML_UNARY_OP_RELU:
+            {
+                ggml_compute_forward_relu(params, dst);
+            } break;
+        case GGML_UNARY_OP_SIGMOID:
+            {
+                ggml_compute_forward_sigmoid(params, dst);
+            } break;
+        case GGML_UNARY_OP_GELU:
+            {
+                ggml_compute_forward_gelu(params, dst);
+            } break;
+        case GGML_UNARY_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, dst);
+            } break;
+        case GGML_UNARY_OP_SILU:
+            {
+                ggml_compute_forward_silu(params, dst);
+            } break;
+        case GGML_UNARY_OP_HARDSWISH:
+            {
+                ggml_compute_forward_hardswish(params, dst);
+            } break;
+        case GGML_UNARY_OP_HARDSIGMOID:
+            {
+                ggml_compute_forward_hardsigmoid(params, dst);
+            } break;
+        case GGML_UNARY_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_get_rel_pos
+
+static void ggml_compute_forward_get_rel_pos_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    GGML_UNUSED(params);
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int64_t w = ne1;
+
+    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
+    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            const int64_t pos = (w - i1 - 1) + i2;
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
+            }
+        }
+    }
+}
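+
+// The lookup above simplifies to pos = (i2 - i1) + (w - 1): row (i2 - i1) + (w - 1) of src0
+// holds the embedding for relative position i2 - i1, shifted by w - 1 so the index stays
+// non-negative (a sketch of the intent, per the segment-anything reference linked above).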
+
+void ggml_compute_forward_get_rel_pos(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_get_rel_pos_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_add_rel_pos
+
+static void ggml_compute_forward_add_rel_pos_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
+    if (!inplace) {
+        if (params->ith == 0) {
+            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
+
+    float * src1_data = (float *) src1->data;
+    float * src2_data = (float *) src2->data;
+    float * dst_data  = (float *) dst->data;
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // total patches in dst
+    const int np = ne13;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    for (int64_t i13 = ip0; i13 < ip1; ++i13) {
+        for (int64_t i12 = 0; i12 < ne12; ++i12) {
+            for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
+                for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                    const int64_t jp0  = jp1 + i10;
+                    const float src1_e = src1_data[jp0];
+                    const float src2_e = src2_data[jp0];
+
+                    const int64_t jdh = jp0 * ne10;
+                    const int64_t jdw = jdh - (ne10 - 1) * i10;
+
+                    for (int64_t j = 0; j < ne10; ++j) {
+                        dst_data[jdh + j     ] += src2_e;
+                        dst_data[jdw + j*ne10] += src1_e;
+                    }
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_add_rel_pos(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_rel_pos_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rwkv_wkv6
+
+static void ggml_compute_forward_rwkv_wkv6_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[5]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k =          (float *) dst->src[0]->data;
+    float * v =          (float *) dst->src[1]->data;
+    float * r =          (float *) dst->src[2]->data;
+    float * time_faaaa = (float *) dst->src[3]->data;
+    float * time_decay = (float *) dst->src[4]->data;
+
+    size_t t_stride = HEADS * head_size; // same as C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
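+    // Per-sequence state handling used by both branches below: the running state is stored
+    // after the T*C outputs in dst (see `state` above); for the first token of each sequence
+    // state_prev reads the incoming state from dst->src[5], while for later tokens it reads
+    // back state_cur so the recurrence chains through the sequence.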
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define WKV_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define WKV_VECTOR_SIZE 4
+    #endif
+
+    #ifdef WKV_VECTOR_SIZE
+        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load WKV_VECTOR_SIZE elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = kv * time_faaaa + prev_state
+                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
+
+                        // Update dst: dst += temp * r
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state: state = prev_state * time_decay + kv
+                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                    }
+
+                    // Handle remaining elements (not entered when head_size is a multiple of WKV_VECTOR_SIZE)
+                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        // basically fused operations:
+        // dst = r @ (time_faaaa * (k @ v) + state),
+        // state = time_decay * state + (k @ v),
+        // recursive through each token
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    // RWKV v6: different time_decay for each token.
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+void ggml_compute_forward_rwkv_wkv6(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // same as C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load GLA_VECTOR_SIZE elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements (not entered when head_size is a multiple of GLA_VECTOR_SIZE)
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
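+
+// Element-wise form of the gated linear attention recurrence above (a sketch; per head and
+// per token, with the state S indexed [i, j], i striding by h_stride):
+//
+//   S[i, j]  = g[i] * S_prev[i, j] + k[i] * v[j]
+//   out[j]  += scale * q[i] * S[i, j]            // summed over i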
+
+
+void ggml_compute_forward_gla(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rwkv_wkv7
+
+static void ggml_compute_forward_rwkv_wkv7_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[6]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * r = (float *) dst->src[0]->data;
+    float * w = (float *) dst->src[1]->data;
+    float * k = (float *) dst->src[2]->data;
+    float * v = (float *) dst->src[3]->data;
+    float * a = (float *) dst->src[4]->data;
+    float * b = (float *) dst->src[5]->data;
+
+    int64_t t_stride = HEADS * head_size; // same as C
+
+    int64_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    int64_t h_stride_2d = head_size * head_size;
+
+    #if defined(GGML_SIMD)
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(sa, sum);
+                    }
+
+                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                    int64_t j = 0;
+                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // Scalar tail for any left-over elements (there should not be any in practice).
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                }
+            }
+        }
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    int64_t t_h_i_offset = t_h_offset + i;
+                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float v_val = v[t_h_i_offset];
+
+                    float sa = 0, result = 0;
+                    for (int64_t j = 0; j < head_size; j++) {
+                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                    }
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        result += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                    dst_data[t_h_i_offset] = result;
+                }
+            }
+        }
+    #endif
+}
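+
+// Element-wise form of the WKV7 recurrence above (a sketch; per head and per token, with the
+// state S indexed [i, j], i striding by h_stride):
+//
+//   sa[i]   = sum_j a[j] * S_prev[i, j]          // dot(a, previous state row i)
+//   S[i, j] = S_prev[i, j] * w[j] + k[j] * v[i] + sa[i] * b[j]
+//   out[i]  = sum_j S[i, j] * r[j]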
+
+
+void ggml_compute_forward_rwkv_wkv7(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_map_custom1
+
+void ggml_compute_forward_map_custom1(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * a = dst->src[0];
+
+    struct ggml_map_custom1_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, a, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_map_custom2
+
+void ggml_compute_forward_map_custom2(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * a = dst->src[0];
+    const ggml_tensor * b = dst->src[1];
+
+    struct ggml_map_custom2_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_map_custom3
+
+void ggml_compute_forward_map_custom3(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * a = dst->src[0];
+    const ggml_tensor * b = dst->src[1];
+    const ggml_tensor * c = dst->src[2];
+
+    struct ggml_map_custom3_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_custom
+
+void ggml_compute_forward_custom(
+    const struct ggml_compute_params * params,
+          struct ggml_tensor * dst) {
+
+    struct ggml_custom_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    float * sums =  (float *) params->wdata;
+    float * st   = ((float *) params->wdata) + nth + ith*nc;
+    float sum_thread = 0.0f;
+
+    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
+        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
+        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, s0);
+        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
+        assert(sum_softmax >= 0.0);
+
+        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
+        ggml_vec_mul_f32(nc, st, st, s1);
+
+        float sum_st = 0.0f;
+        ggml_vec_sum_f32(nc, &sum_st, st);
+        sum_thread += sum_st;
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            assert(!isnan(st[i]));
+            assert(!isinf(st[i]));
+        }
+#endif
+    }
+    sums[ith] = sum_thread;
+    ggml_barrier(params->threadpool);
+
+    if (ith == 0) {
+        float * dp = (float *) dst->data;
+        ggml_vec_sum_f32(nth, dp, sums);
+        dp[0] *= -1.0f / (float) nr;
+    }
+}
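+
+// What the function above computes (a sketch): per row, st ends up holding log_softmax(s0)
+// (logits minus max minus log-sum-exp), so the final value is
+//
+//   loss = -(1/nr) * sum_rows sum_c src1[row, c] * log_softmax(src0[row, :])[c]
+//
+// i.e. the mean over rows of the cross entropy between the logits in src0 and the target
+// distribution in src1.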
+
+void ggml_compute_forward_cross_entropy_loss(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss_back
+
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
+    const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
+    const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
+    GGML_ASSERT(ggml_is_contiguous(grad));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
+
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0f->ne[0];
+    const int64_t nr = ggml_nrows(src0f);
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
+
+    for (int64_t i1 = ir0; i1 < ir1; i1++) {
+        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
+        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
+        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+
+        // soft_max
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, s0);
+        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        assert(sum > 0.0);
+        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
+
+        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
+        ggml_vec_sub_f32(nc, ds0, ds0, s1);
+        ggml_vec_scale_f32(nc, ds0, d_by_nr);
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            assert(!isnan(ds0[i]));
+            assert(!isinf(ds0[i]));
+        }
+#endif
+    }
+}
+
+void ggml_compute_forward_cross_entropy_loss_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_opt_step_adamw_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * src0_grad_m  = dst->src[2];
+    const ggml_tensor * src0_grad_v  = dst->src[3];
+    const ggml_tensor * adamw_params = dst->src[4];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float alpha  = adamw_params_ptr[0];
+    const float beta1  = adamw_params_ptr[1];
+    const float beta2  = adamw_params_ptr[2];
+    const float eps    = adamw_params_ptr[3];
+    const float wd     = adamw_params_ptr[4];
+    const float beta1h = adamw_params_ptr[5];
+    const float beta2h = adamw_params_ptr[6];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
+        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
+        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
+            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
+
+            const float mh =       m[i00]*beta1h;
+            const float vh = sqrtf(v[i00]*beta2h) + eps;
+
+            // The weight decay is applied independently of the Adam momenta m and v.
+            // This is NOT equivalent to L2 regularization, which adds w[i00]*w[i00] to the loss.
+            // See: https://arxiv.org/pdf/1711.05101v3.pdf
+            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+        }
+    }
+}
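+
+// For reference, the scalar update implemented above (AdamW, see the paper
+// linked in the loop; beta1h and beta2h are assumed to be bias-correction
+// factors precomputed by the caller and passed via adamw_params):
+//
+//   m = beta1*m + (1 - beta1)*g
+//   v = beta2*v + (1 - beta2)*g*g
+//   w = w*(1 - alpha*wd) - alpha * (m*beta1h) / (sqrt(v*beta2h) + eps)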
+
+void ggml_compute_forward_opt_step_adamw(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_adamw_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ops.h b/ml/backend/ggml/ggml/src/ggml-cpu/ops.h
new file mode 100644
index 000000000..3eca1cf86
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ops.h
@@ -0,0 +1,110 @@
+#pragma once
+
+#include "ggml.h"
+
+//
+// cache line
+//
+
+#if defined(__cpp_lib_hardware_interference_size)
+#define CACHE_LINE_SIZE std::hardware_destructive_interference_size
+#else
+#if defined(__POWER9_VECTOR__)
+#define CACHE_LINE_SIZE 128
+#elif defined(__VXE__) || defined(__VXE2__)
+#define CACHE_LINE_SIZE 256
+#else
+#define CACHE_LINE_SIZE 64
+#endif
+#endif
+
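+// cache line size expressed in number of float elements (e.g. 64/4 = 16 for a 64-byte line)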
+static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_repeat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_repeat_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_concat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_l2_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_out_prod(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_scale(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_set(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_cpy(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_cont(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_reshape(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_view(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_permute(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_soft_max(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_soft_max_ext_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_rope(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_rope_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_pool_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_unpad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_flash_attn_ext(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * q,
+    const struct ggml_tensor * k,
+    const struct ggml_tensor * v,
+    const struct ggml_tensor * mask,
+    struct ggml_tensor * dst);
+void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const bool masked,
+        struct ggml_tensor * dst);
+void ggml_compute_forward_ssm_conv(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_custom(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h b/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h
new file mode 100644
index 000000000..04d10cec2
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h
@@ -0,0 +1,892 @@
+#pragma once
+
+#include "ggml-cpu-impl.h"
+
+//
+// simd mappings
+//
+
+// we define a common set of C macros which map to specific intrinsics based on the current architecture
+// we then implement the fundamental computation operations below using only these macros
+// adding support for new architectures requires defining the corresponding SIMD macros
+//
+// GGML_F32_STEP / GGML_F16_STEP
+//   number of elements to process in a single step
+//
+// GGML_F32_EPR / GGML_F16_EPR
+//   number of elements to fit in a single register
+//
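+// Illustrative only: a kernel written against these macros (e.g. the dot product
+// in vec.cpp) processes GGML_F32_STEP elements per iteration, split across
+// GGML_F32_ARR registers of GGML_F32_EPR elements each:
+//
+//   const int np = (n & ~(GGML_F32_STEP - 1));              // SIMD-able prefix
+//   GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+//   GGML_F32_VEC ax[GGML_F32_ARR], ay[GGML_F32_ARR];
+//   for (int i = 0; i < np; i += GGML_F32_STEP) {
+//       for (int j = 0; j < GGML_F32_ARR; j++) {
+//           ax[j]  = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+//           ay[j]  = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+//           sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+//       }
+//   }
+//   GGML_F32_VEC_REDUCE(sumf, sum);                         // horizontal sum
+//   // the remaining n - np elements are handled with plain scalar code
+//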
+
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
+
+#define GGML_SIMD
+
+// F32 NEON
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              float32x4_t
+#define GGML_F32x4_ZERO         vdupq_n_f32(0.0f)
+#define GGML_F32x4_SET1(x)      vdupq_n_f32(x)
+#define GGML_F32x4_LOAD         vld1q_f32
+#define GGML_F32x4_STORE        vst1q_f32
+#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
+#define GGML_F32x4_ADD          vaddq_f32
+#define GGML_F32x4_MUL          vmulq_f32
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
+#define GGML_F32x4_REDUCE(res, x)                       \
+{                                                       \
+    int offset = GGML_F32_ARR >> 1;                     \
+    for (int i = 0; i < offset; ++i) {                  \
+        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \
+    }                                                   \
+    offset >>= 1;                                       \
+    for (int i = 0; i < offset; ++i) {                  \
+        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \
+    }                                                   \
+    offset >>= 1;                                       \
+    for (int i = 0; i < offset; ++i) {                  \
+        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \
+    }                                                   \
+    (res) = (ggml_float) GGML_F32x4_REDUCE_ONE((x)[0]); \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    #define GGML_F16_STEP 32
+    #define GGML_F16_EPR  8
+
+    #define GGML_F16x8              float16x8_t
+    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)
+    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)
+    #define GGML_F16x8_LOAD(x)      vld1q_f16((const __fp16 *)(x))
+    #define GGML_F16x8_STORE        vst1q_f16
+    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
+    #define GGML_F16x8_ADD          vaddq_f16
+    #define GGML_F16x8_MUL          vmulq_f16
+    #define GGML_F16x8_REDUCE(res, x)                               \
+    do {                                                            \
+        int offset = GGML_F16_ARR >> 1;                             \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        offset >>= 1;                                               \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        offset >>= 1;                                               \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
+        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
+        (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
+    } while (0)
+
+    #define GGML_F16_VEC                GGML_F16x8
+    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), (r)[i])
+    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
+    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
+    #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE
+#else
+    // if FP16 vector arithmetic is not supported, we use FP32 instead
+    // and take advantage of the vcvt_ functions to convert to/from FP16
+
+    #define GGML_F16_STEP 16
+    #define GGML_F16_EPR  4
+
+    #define GGML_F32Cx4              float32x4_t
+    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)
+    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)
+    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16((const __fp16 *)(x)))
+    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))
+    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
+    #define GGML_F32Cx4_ADD          vaddq_f32
+    #define GGML_F32Cx4_MUL          vmulq_f32
+    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE
+
+    #define GGML_F16_VEC                GGML_F32Cx4
+    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i])
+    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
+    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
+    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
+#endif
+
+#elif defined(__AVX512F__)
+
+#define GGML_SIMD
+
+// F32 AVX512
+
+#define GGML_F32_STEP 64
+#define GGML_F32_EPR  16
+
+#define GGML_F32x16         __m512
+#define GGML_F32x16_ZERO    _mm512_setzero_ps()
+#define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
+#define GGML_F32x16_LOAD    _mm512_loadu_ps
+#define GGML_F32x16_STORE   _mm512_storeu_ps
+// _mm512_fmadd_ps is defined in AVX512F so no guard is required
+#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
+#define GGML_F32x16_ADD     _mm512_add_ps
+#define GGML_F32x16_MUL     _mm512_mul_ps
+#define GGML_F32x16_REDUCE(res, x)                                    \
+do {                                                                  \
+    int offset = GGML_F32_ARR >> 1;                                   \
+    for (int i = 0; i < offset; ++i) {                                \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
+    }                                                                 \
+    offset >>= 1;                                                     \
+    for (int i = 0; i < offset; ++i) {                                \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
+    }                                                                 \
+    offset >>= 1;                                                     \
+    for (int i = 0; i < offset; ++i) {                                \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
+    }                                                                 \
+    res = (ggml_float) _mm512_reduce_add_ps(x[0]);                    \
+} while (0)
+
+// TODO: is this optimal?
+
+#define GGML_F32_VEC        GGML_F32x16
+#define GGML_F32_VEC_ZERO   GGML_F32x16_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x16_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x16_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x16_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x16_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x16_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x16_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
+
+// F16 AVX512
+
+#define GGML_F16_STEP 64
+#define GGML_F16_EPR  16
+
+// AVX512 has an FP16 extension (AVX512_FP16), but it is not assumed to be available, so FP32 is used instead
+
+#define GGML_F32Cx16             __m512
+#define GGML_F32Cx16_ZERO        _mm512_setzero_ps()
+#define GGML_F32Cx16_SET1(x)     _mm512_set1_ps(x)
+
+// unlike the _mm256_cvt intrinsics, which require F16C, _mm512_cvt is defined in AVX512F,
+// so no F16C guard is required
+#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
+#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
+
+#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
+#define GGML_F32Cx16_ADD         _mm512_add_ps
+#define GGML_F32Cx16_MUL         _mm512_mul_ps
+#define GGML_F32Cx16_REDUCE(res, x)                               \
+do {                                                              \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    res = (ggml_float) _mm512_reduce_add_ps(x[0]);                \
+} while (0)
+
+#define GGML_F16_VEC                GGML_F32Cx16
+#define GGML_F16_VEC_ZERO           GGML_F32Cx16_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx16_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx16_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx16_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx16_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx16_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx16_REDUCE
+
+#elif defined(__AVX__)
+
+#define GGML_SIMD
+
+// F32 AVX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    _mm256_setzero_ps()
+#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
+#define GGML_F32x8_LOAD    _mm256_loadu_ps
+#define GGML_F32x8_STORE   _mm256_storeu_ps
+#if defined(__FMA__)
+    #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
+#else
+    #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
+#endif
+#define GGML_F32x8_ADD     _mm256_add_ps
+#define GGML_F32x8_MUL     _mm256_mul_ps
+#define GGML_F32x8_REDUCE(res, x)                                 \
+do {                                                              \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \
+                                 _mm256_extractf128_ps(x[0], 1)); \
+    const __m128 t1 = _mm_hadd_ps(t0, t0);                        \
+    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1));        \
+} while (0)
+// TODO: is this optimal?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 AVX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+
+#define GGML_F32Cx8             __m256
+#define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
+#define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
+
+#if defined(__F16C__)
+// the _mm256_cvt intrinsics require F16C
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
+#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
+#else
+static inline __m256 __avx_f32cx8_load(const ggml_fp16_t * x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return _mm256_loadu_ps(tmp);
+}
+static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+    float arr[8];
+
+    _mm256_storeu_ps(arr, y);
+
+    for (int i = 0; i < 8; i++)
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+}
+#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
+#endif
+
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         _mm256_add_ps
+#define GGML_F32Cx8_MUL         _mm256_mul_ps
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_SIMD
+
+// F32 POWER9
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              vector float
+#define GGML_F32x4_ZERO         0.0f
+#define GGML_F32x4_SET1         vec_splats
+#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
+#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD          vec_add
+#define GGML_F32x4_MUL          vec_mul
+#define GGML_F32x4_REDUCE(res, x)              \
+{                                              \
+    int offset = GGML_F32_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    res = vec_extract(x[0], 0) +               \
+          vec_extract(x[0], 1) +               \
+          vec_extract(x[0], 2) +               \
+          vec_extract(x[0], 3);                \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 POWER9
+#define GGML_F16_STEP       GGML_F32_STEP
+#define GGML_F16_EPR        GGML_F32_EPR
+#define GGML_F16_VEC        GGML_F32x4
+#define GGML_F16_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F16_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F16_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F16_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F16_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
+// Use vec_xl, not vec_ld, in case the load address is not aligned.
+#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \
+  vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
+  vec_extract_fp32_from_shortl(vec_xl(0, p))
+static inline unsigned char ggml_endian_byte(int i) {
+    uint16_t tmp_val = 1;
+    return ((unsigned char *)&tmp_val)[i];
+}
+#define GGML_ENDIAN_BYTE(i) ggml_endian_byte(i)
+#define GGML_F16_VEC_STORE(p, r, i)                             \
+  if (i & 0x1)                                                  \
+    vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \
+                                   r[i - GGML_ENDIAN_BYTE(0)]), \
+            0, p - GGML_F16_EPR)
+
+#elif defined(__wasm_simd128__)
+
+#define GGML_SIMD
+
+// F32 WASM
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              v128_t
+#define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f)
+#define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x)
+#define GGML_F32x4_LOAD         wasm_v128_load
+#define GGML_F32x4_STORE        wasm_v128_store
+#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
+#define GGML_F32x4_ADD          wasm_f32x4_add
+#define GGML_F32x4_MUL          wasm_f32x4_mul
+#define GGML_F32x4_REDUCE(res, x)                  \
+{                                                  \
+    int offset = GGML_F32_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 WASM
+
+#define GGML_F16_STEP 16
+#define GGML_F16_EPR  4
+
+inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(p[0]);
+    tmp[1] = GGML_FP16_TO_FP32(p[1]);
+    tmp[2] = GGML_FP16_TO_FP32(p[2]);
+    tmp[3] = GGML_FP16_TO_FP32(p[3]);
+
+    return wasm_v128_load(tmp);
+}
+
+inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
+    float tmp[4];
+
+    wasm_v128_store(tmp, x);
+
+    p[0] = GGML_FP32_TO_FP16(tmp[0]);
+    p[1] = GGML_FP32_TO_FP16(tmp[1]);
+    p[2] = GGML_FP32_TO_FP16(tmp[2]);
+    p[3] = GGML_FP32_TO_FP16(tmp[3]);
+}
+
+#define GGML_F16x4             v128_t
+#define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f)
+#define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x)
+#define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x)
+#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
+#define GGML_F16x4_FMA         GGML_F32x4_FMA
+#define GGML_F16x4_ADD         wasm_f32x4_add
+#define GGML_F16x4_MUL         wasm_f32x4_mul
+#define GGML_F16x4_REDUCE(res, x)                           \
+{                                                           \
+    int offset = GGML_F16_ARR >> 1;                         \
+    for (int i = 0; i < offset; ++i) {                      \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);           \
+    }                                                       \
+    offset >>= 1;                                           \
+    for (int i = 0; i < offset; ++i) {                      \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);           \
+    }                                                       \
+    offset >>= 1;                                           \
+    for (int i = 0; i < offset; ++i) {                      \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);           \
+    }                                                       \
+    res = (ggml_float) (wasm_f32x4_extract_lane(x[0], 0) +  \
+          wasm_f32x4_extract_lane(x[0], 1) +                \
+          wasm_f32x4_extract_lane(x[0], 2) +                \
+          wasm_f32x4_extract_lane(x[0], 3));                \
+}
+
+#define GGML_F16_VEC                GGML_F16x4
+#define GGML_F16_VEC_ZERO           GGML_F16x4_ZERO
+#define GGML_F16_VEC_SET1           GGML_F16x4_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F16x4_FMA
+#define GGML_F16_VEC_ADD            GGML_F16x4_ADD
+#define GGML_F16_VEC_MUL            GGML_F16x4_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE
+
+#elif defined(__SSE3__)
+
+#define GGML_SIMD
+
+// F32 SSE
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    _mm_setzero_ps()
+#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
+#define GGML_F32x4_LOAD    _mm_loadu_ps
+#define GGML_F32x4_STORE   _mm_storeu_ps
+#if defined(__FMA__)
+    // TODO: Does this work?
+    #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
+#else
+    #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
+#endif
+#define GGML_F32x4_ADD     _mm_add_ps
+#define GGML_F32x4_MUL     _mm_mul_ps
+#define GGML_F32x4_REDUCE(res, x)                                 \
+{                                                                 \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \
+    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0));        \
+}
+// TODO: is this optimal?
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 SSE
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __sse_f16x4_load(const ggml_fp16_t * x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return _mm_loadu_ps(tmp);
+}
+
+static inline void __sse_f16x4_store(ggml_fp16_t * x, __m128 y) {
+    float arr[4];
+
+    _mm_storeu_ps(arr, y);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4             __m128
+#define GGML_F32Cx4_ZERO        _mm_setzero_ps()
+#define GGML_F32Cx4_SET1(x)     _mm_set1_ps(x)
+#define GGML_F32Cx4_LOAD(x)     __sse_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD         _mm_add_ps
+#define GGML_F32Cx4_MUL         _mm_mul_ps
+#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
+
+#elif defined(__loongarch_asx)
+
+#define GGML_SIMD
+
+// F32 LASX
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
+#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
+#define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0)
+#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
+#define GGML_F32x8_ADD     __lasx_xvfadd_s
+#define GGML_F32x8_MUL     __lasx_xvfmul_s
+#define GGML_F32x8_REDUCE(res, x)                                 \
+do {                                                              \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
+    }                                                             \
+    float *tmp_p = (float *)&x[0]; \
+    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7];  \
+} while (0)
+// TODO: is this optimal?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 LASX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by LASX, so we use F32 instead
+
+#define GGML_F32Cx8          __m256
+#define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+
+static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
+    __m256i a;
+    memcpy(&a, x, sizeof(ggml_fp16_t) * 8);
+    a = __lasx_xvpermi_d(a, 0 | (1 << 4));
+    return __lasx_xvfcvtl_s_h(a);
+}
+
+static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
+    __m256i a = __lasx_xvfcvt_h_s(y, y);
+    a = __lasx_xvpermi_d(a, 0 | (2 << 2));
+    memcpy(x, &a, sizeof(ggml_fp16_t) * 8);
+}
+#define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
+
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         __lasx_xvfadd_s
+#define GGML_F32Cx8_MUL         __lasx_xvfmul_s
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__loongarch_sx)
+
+#define GGML_SIMD
+
+// F32 LSX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    __lsx_vldi(0)
+#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+#define GGML_F32x4_STORE(x, y)   __lsx_vst((y), (x), 0)
+#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
+#define GGML_F32x4_ADD     __lsx_vfadd_s
+#define GGML_F32x4_MUL     __lsx_vfmul_s
+#define GGML_F32x4_REDUCE(res, x)                                                     \
+{                                                                                     \
+    int offset = GGML_F32_ARR >> 1;                                                   \
+    for (int i = 0; i < offset; ++i) {                                                \
+        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
+    }                                                                                 \
+    offset >>= 1;                                                                     \
+    for (int i = 0; i < offset; ++i) {                                                \
+        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
+    }                                                                                 \
+    offset >>= 1;                                                                     \
+    for (int i = 0; i < offset; ++i) {                                                \
+        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
+    }                                                                                 \
+    __m128i tmp     = __lsx_vsrli_d((__m128i) x[0], 32);                              \
+    tmp             = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]);                    \
+    tmp             = __lsx_vpickev_w(__lsx_vldi(0), tmp);                            \
+    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88);                                     \
+    tmp             = __lsx_vsrli_d((__m128i) t0, 32);                                \
+    tmp             = (__m128i) __lsx_vfadd_s((__m128) tmp, t0);                      \
+    tmp             = __lsx_vpickev_w(__lsx_vldi(0), tmp);                            \
+    res             = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 LSX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return __lsx_vld(tmp, 0);
+}
+
+static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
+    float arr[4];
+
+    __lsx_vst(y, arr, 0);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4             __m128
+#define GGML_F32Cx4_ZERO        __lsx_vldi(0)
+#define GGML_F32Cx4_SET1(x)     __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_LOAD(x)     __lsx_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD         __lsx_vfadd_s
+#define GGML_F32Cx4_MUL         __lsx_vfmul_s
+#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
+
+#elif defined(__VXE__) || defined(__VXE2__)
+
+#define GGML_SIMD
+
+// F32 s390x
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              __vector float
+#define GGML_F32x4_ZERO         vec_splats(0.0f)
+#define GGML_F32x4_SET1         vec_splats
+#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
+#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD          vec_add
+#define GGML_F32x4_MUL          vec_mul
+#define GGML_F32x4_REDUCE(res, x)                   \
+{                                                   \
+    int offset = GGML_F32_ARR >> 1;                 \
+    for (int i = 0; i < offset; ++i) {              \
+        x[i] = vec_add(x[i], x[offset + i]);        \
+    }                                               \
+    offset >>= 1;                                   \
+    for (int i = 0; i < offset; ++i) {              \
+        x[i] = vec_add(x[i], x[offset + i]);        \
+    }                                               \
+    offset >>= 1;                                   \
+    for (int i = 0; i < offset; ++i) {              \
+        x[i] = vec_add(x[i], x[offset + i]);        \
+    }                                               \
+    res = vec_extract(x[0], 0) +                    \
+          vec_extract(x[0], 1) +                    \
+          vec_extract(x[0], 2) +                    \
+          vec_extract(x[0], 3);                     \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 s390x
+#define GGML_F16_STEP GGML_F32_STEP
+#define GGML_F16_EPR  GGML_F32_EPR
+
+static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) {
+    float tmp[4];
+
+    for (int i = 0; i < 4; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    // note: keep type-cast here to prevent compiler bugs
+    // see: https://github.com/ggml-org/llama.cpp/issues/12846
+    return vec_xl(0, (const float *)(tmp));
+}
+
+static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) {
+    float arr[4];
+
+    // note: keep type-cast here to prevent compiler bugs
+    // see: https://github.com/ggml-org/llama.cpp/issues/12846
+    vec_xst(y, 0, (float *)(arr));
+
+    for (int i = 0; i < 4; i++) {
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+    }
+}
+
+#define GGML_F16_VEC                GGML_F32x4
+#define GGML_F16_VEC_ZERO           GGML_F32x4_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32x4_SET1
+#define GGML_F16_VEC_LOAD(p, i)     __lzs_f16cx4_load(p)
+#define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32x4_FMA
+#define GGML_F16_VEC_ADD            GGML_F32x4_ADD
+#define GGML_F16_VEC_MUL            GGML_F32x4_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32x4_REDUCE
+
+#endif
+
+// GGML_F32_ARR / GGML_F16_ARR
+//   number of registers to use per step
+#ifdef GGML_SIMD
+#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
+#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
+#endif
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp
new file mode 100644
index 000000000..4fce569b3
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp
@@ -0,0 +1,186 @@
+#include "unary-ops.h"
+
+static inline float op_abs(float x) {
+    return fabsf(x);
+}
+
+static inline float op_sgn(float x) {
+    return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f);
+}
+
+static inline float op_neg(float x) {
+    return -x;
+}
+
+static inline float op_step(float x) {
+    return (x > 0.f) ? 1.f : 0.f;
+}
+
+static inline float op_tanh(float x) {
+    return tanhf(x);
+}
+
+static inline float op_elu(float x) {
+    return (x > 0.f) ? x : expm1f(x);
+}
+
+static inline float op_relu(float x) {
+    return (x > 0.f) ? x : 0.f;
+}
+
+static inline float op_sigmoid(float x) {
+    return 1.f / (1.f + expf(-x));
+}
+
+static inline float op_hardsigmoid(float x) {
+    return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
+}
+
+static inline float op_exp(float x) {
+    return expf(x);
+}
+
+static inline float op_hardswish(float x) {
+    return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
+}
+
+static inline float op_sqr(float x) {
+    return x * x;
+}
+
+static inline float op_sqrt(float x) {
+    return sqrtf(x);
+}
+
+static inline float op_sin(float x) {
+    return sinf(x);
+}
+
+static inline float op_cos(float x) {
+    return cosf(x);
+}
+
+static inline float op_log(float x) {
+    return logf(x);
+}
+
+template <float (*op)(float), typename src0_t, typename dst_t>
+static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
+    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+    constexpr auto f32_to_dst  = type_conversion_table<dst_t>::from_f32;
+
+    for (int i = 0; i < n; i++) {
+        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
+    }
+}
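+
+// Illustrative instantiation (example only): vec_unary_op<op_relu, ggml_fp16_t, float>(n, y, x)
+// reads n f16 values from x, applies relu in f32, and writes f32 results to y.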
+
+template <float (*op)(float), typename src0_t, typename dst_t>
+static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(dst_t));
+    GGML_ASSERT(nb00 == sizeof(src0_t));
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        dst_t        * dst_ptr  = (dst_t  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+        vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
+    }
+}
+
+// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
+template <float (*op)(float)>
+static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    /*  */ if (src0->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32
+        apply_unary_op<op, float, float>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16
+        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F32) {
+        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_abs>(params, dst);
+}
+
+void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_sgn>(params, dst);
+}
+
+void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_neg>(params, dst);
+}
+
+void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_step>(params, dst);
+}
+
+void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_tanh>(params, dst);
+}
+
+void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_elu>(params, dst);
+}
+
+void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_relu>(params, dst);
+}
+
+void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_sigmoid>(params, dst);
+}
+
+void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_hardsigmoid>(params, dst);
+}
+
+void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_exp>(params, dst);
+}
+
+void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_hardswish>(params, dst);
+}
+
+void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_sqr>(params, dst);
+}
+
+void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_sqrt>(params, dst);
+}
+
+void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_sin>(params, dst);
+}
+
+void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_cos>(params, dst);
+}
+
+void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_log>(params, dst);
+}
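The file above collapses sixteen element-wise operators into a single template pair: the scalar op is passed as a compile-time function pointer, and unary_op() selects the (src0, dst) type instantiation at run time. A minimal standalone sketch of the same pattern follows; the enum and helper names are illustrative and are not ggml symbols.

#include <cmath>
#include <cstdio>
#include <vector>

enum class Type { F32, F64 };

// Convert to f32, apply the scalar op, convert back to the destination type.
template <float (*op)(float), typename src_t, typename dst_t>
static void apply(const src_t * x, dst_t * y, int n) {
    for (int i = 0; i < n; ++i) {
        y[i] = (dst_t) op((float) x[i]);
    }
}

static float op_relu(float x) { return x > 0.0f ? x : 0.0f; }

// Runtime type checks pick the instantiation, mirroring unary_op() above.
static void forward_relu(Type src, Type dst, const void * x, void * y, int n) {
    if (src == Type::F32 && dst == Type::F32) {
        apply<op_relu, float, float>((const float *) x, (float *) y, n);
    } else if (src == Type::F64 && dst == Type::F32) {
        apply<op_relu, double, float>((const double *) x, (float *) y, n);
    } else {
        std::fprintf(stderr, "unsupported type combination\n");
    }
}

int main() {
    std::vector<float> x = {-1.0f, 0.5f, 2.0f}, y(3);
    forward_relu(Type::F32, Type::F32, x.data(), y.data(), 3);
    for (float v : y) std::printf("%g ", v);   // prints: 0 0.5 2
    std::printf("\n");
    return 0;
}

Adding another operator then costs one scalar function plus one thin forward wrapper, which is the point of the refactor.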
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h
new file mode 100644
index 000000000..b1ade2c8e
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void ggml_compute_forward_abs(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_sgn(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_neg(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_step(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_hardswish(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp
new file mode 100644
index 000000000..dfe2218e3
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp
@@ -0,0 +1,258 @@
+#include "vec.h"
+
+#include <cassert>
+
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+#endif
+
+// precomputed gelu table for f16 (128 KB)
+ggml_fp16_t ggml_table_gelu_f16[1 << 16];
+
+// precomputed quick gelu table for f16 (128 KB)
+ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
+
+void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {
+   assert(nrc == 1);
+   GGML_UNUSED(nrc);
+   GGML_UNUSED(bx);
+   GGML_UNUSED(by);
+   GGML_UNUSED(bs);
+
+#if defined(GGML_SIMD)
+    float sumf = 0.0f;
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    GGML_F32_VEC_REDUCE(sumf, sum);
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += x[i]*y[i];
+    }
+#else
+    // scalar
+    ggml_float sumf = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sumf += (ggml_float)(x[i]*y[i]);
+    }
+#endif
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc) {
+    assert(nrc == 1);
+    GGML_UNUSED(nrc);
+    GGML_UNUSED(bx);
+    GGML_UNUSED(by);
+    GGML_UNUSED(bs);
+    int i = 0;
+    ggml_float sumf = 0;
+
+#if defined(__AVX512BF16__)
+    __m512 c1 = _mm512_setzero_ps();
+    __m512 c2 = _mm512_setzero_ps();
+    for (; i + 64 <= n; i += 64) {
+        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+                             m512bh(_mm512_loadu_si512((y + i))));
+        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+                             m512bh(_mm512_loadu_si512((y + i + 32))));
+    }
+    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
+    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
+
+#elif defined(__AVX512F__)
+#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
+    __m512 c1 = _mm512_setzero_ps();
+    __m512 c2 = _mm512_setzero_ps();
+    for (; i + 32 <= n; i += 32) {
+        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
+    }
+    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
+    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
+
+#undef LOAD
+#elif defined(__AVX2__) || defined(__AVX__)
+#if defined(__AVX2__)
+#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
+#else
+#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
+#endif
+    __m256 c1 = _mm256_setzero_ps();
+    __m256 c2 = _mm256_setzero_ps();
+    __m256 c3 = _mm256_setzero_ps();
+    __m256 c4 = _mm256_setzero_ps();
+    for (; i + 32 <= n; i += 32) {
+        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
+        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
+        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
+    }
+    __m128 g;
+    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
+                       _mm256_add_ps(c2, c4));
+    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
+                   _mm256_castps256_ps128(c1));
+    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
+    g = _mm_add_ss(g, _mm_movehdup_ps(g));
+    sumf += (ggml_float)_mm_cvtss_f32(g);
+
+#undef LOAD
+#endif
+
+    for (; i < n; ++i) {
+        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
+                             GGML_BF16_TO_FP32(y[i]));
+    }
+    *s = sumf;
+}
+
+void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc) {
+    assert(nrc == 1);
+    GGML_UNUSED(nrc);
+    GGML_UNUSED(bx);
+    GGML_UNUSED(by);
+    GGML_UNUSED(bs);
+
+    ggml_float sumf = 0.0;
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    GGML_F16_VEC_REDUCE(sumf, sum);
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+    }
+#endif
+
+    *s = sumf;
+}
+
+void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = ggml_silu_f32(x[i]);
+    }
+}
+
+ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
+    int i = 0;
+    ggml_float sum = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
+                                               _mm512_set1_ps(max)));
+        _mm512_storeu_ps(y + i, val);
+        sum += (ggml_float)_mm512_reduce_add_ps(val);
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
+                                               _mm256_set1_ps(max)));
+        _mm256_storeu_ps(y + i, val);
+        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+                                 _mm256_castps256_ps128(val));
+        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+        sum += (ggml_float)_mm_cvtss_f32(val2);
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
+                                            _mm_set1_ps(max)));
+        _mm_storeu_ps(y + i, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+        val = _mm_add_ps(val, tmp);
+        tmp = _mm_movehl_ps(tmp, val);
+        val = _mm_add_ss(val, tmp);
+#endif
+        sum += (ggml_float)_mm_cvtss_f32(val);
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
+                                                vdupq_n_f32(max)));
+        vst1q_f32(y + i, val);
+        sum += (ggml_float)vaddvq_f32(val);
+    }
+#endif
+    for (; i < n; ++i) {
+        float val = expf(x[i] - max);
+        sum += (ggml_float)val;
+        y[i] = val;
+    }
+    return sum;
+}
+
+ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
+    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_sum)
+
+    int i = 0;
+    ggml_float sum = 0;
+    for (; i < n; ++i) {
+        float val = x[i] - max;
+        y[i] = val;
+        sum += (ggml_float)expf(val);
+    }
+    return sum = (ggml_float)logf(sum);
+}
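Two details in this file are easy to miss: the AVX fallback paths of ggml_vec_dot_bf16 widen bf16 to f32 by moving the 16-bit pattern into the high half of a 32-bit word, and the scalar tails accumulate in ggml_float (a double) to limit rounding error. A scalar sketch of both ideas, with illustrative helper names:

#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 is the upper 16 bits of an IEEE-754 binary32, so widening is a shift.
static float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t) h << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Truncating f32 -> bf16 -> f32 round trip (rounding is omitted for brevity).
static float through_bf16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return bf16_to_f32((uint16_t) (bits >> 16));
}

int main() {
    // Tiny dot product accumulated in double, mirroring the scalar tails above.
    const float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    const float b[4] = {0.5f, 0.25f, 2.0f, 1.0f};
    double sum = 0.0;
    for (int i = 0; i < 4; ++i) {
        sum += (double) (through_bf16(a[i]) * through_bf16(b[i]));
    }
    std::printf("%.2f\n", sum);   // 11.00, every factor is exactly representable in bf16
    return 0;
}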
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h
new file mode 100644
index 000000000..23cbb3051
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h
@@ -0,0 +1,802 @@
+// Vectorized functions for fundamental operations
+
+#pragma once
+
+#include "ggml-impl.h"
+#include "simd-mappings.h"
+#include "ggml.h"
+
+#if defined(GGML_USE_ACCELERATE)
+#include <Accelerate/Accelerate.h>
+#endif
+
+// floating point type used to accumulate sums
+typedef double ggml_float;
+
+#define GGML_GELU_FP16
+#define GGML_GELU_QUICK_FP16
+
+#define GGML_SOFT_MAX_UNROLL 4
+#define GGML_VEC_DOT_UNROLL  2
+#define GGML_VEC_MAD_UNROLL  32
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//
+// global data
+//
+
+// precomputed gelu table for f16 (128 KB)
+extern ggml_fp16_t ggml_table_gelu_f16[1 << 16];
+
+// precomputed quick gelu table for f16 (128 KB)
+extern ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
+
+//
+// fundamental operations
+//
+
+void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
+void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
+void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
+
+void ggml_vec_silu_f32(const int n, float * y, const float * x);
+ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
+ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
+
+inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t   v) { for (int i = 0; i < n; ++i) x[i] = v;    }
+inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
+
+inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) + GGML_FP16_TO_FP32(y[i]));
+    }
+}
+inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
+inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) - GGML_FP16_TO_FP32(y[i]));
+    }
+}
+inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
+inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(-GGML_FP16_TO_FP32(x[i]));
+    }
+}
+
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
+inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) * GGML_FP16_TO_FP32(y[i]));
+    }
+}
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
+inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) / GGML_FP16_TO_FP32(y[i]));
+    }
+}
+
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GGML_RESTRICT s, void * GGML_RESTRICT xv, ggml_fp16_t * GGML_RESTRICT y) {
+    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+    ggml_fp16_t * GGML_RESTRICT x[GGML_VEC_DOT_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+
+                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+            }
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+        }
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+        }
+    }
+#endif
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        s[i] = (float)sumf[i];
+    }
+}
+
+inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#endif
+}
+
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#endif
+}
+
+// xs and vs are byte strides of x and v
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * GGML_RESTRICT y, const float * GGML_RESTRICT xv, const float * GGML_RESTRICT vv) {
+
+    const float * GGML_RESTRICT x[GGML_VEC_MAD_UNROLL];
+    const float * GGML_RESTRICT v[GGML_VEC_MAD_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+        x[i] = (const float *) ((const char *) xv + i*xs);
+        v[i] = (const float *) ((const char *) vv + i*vs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+    }
+
+    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+            }
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = np; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#else
+    // scalar
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = 0; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#endif
+}
+
+//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
+inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
+#if defined(GGML_USE_ACCELERATE)
+    vDSP_vsmul(y, 1, &v, y, 1, n);
+#elif defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] *= v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] *= v;
+    }
+#endif
+}
+
+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#endif
+}
+
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
+inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
+inline static void ggml_vec_sqr_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16(v*v);
+    }
+}
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
+inline static void ggml_vec_sqrt_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(sqrtf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
+inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }
+inline static void ggml_vec_log_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(logf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
+inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }
+inline static void ggml_vec_sin_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(sinf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
+inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }
+inline static void ggml_vec_cos_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(cosf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
+inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
+inline static void ggml_vec_abs_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(fabsf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
+inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
+inline static void ggml_vec_sgn_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? 1.f : ((v < 0.f) ? -1.f : 0.f));
+    }
+}
+inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
+inline static void ggml_vec_step_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16((GGML_FP16_TO_FP32(x[i]) > 0.f) ? 1.f : 0.f);
+    }
+}
+inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
+inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(tanhf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
+inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
+inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(expm1f(GGML_FP16_TO_FP32(x[i])));
+    }
+}
+inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+inline static void ggml_vec_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v : 0.f);
+    }
+}
+inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
+inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const float ns) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? v : 0.f));
+    }
+}
+inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
+inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(1.f / (1.f + expf(-GGML_FP16_TO_FP32(x[i]))));
+    }
+}
+// TODO: optimize performance
+inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_hardswish_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16(v * fminf(1.0f, fmaxf(0.0f, (v + 3.0f) / 6.0f)));
+    }
+}
+inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_hardsigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(fminf(1.0f, fmaxf(0.0f, (GGML_FP16_TO_FP32(x[i]) + 3.0f) / 6.0f)));
+    }
+}
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
+inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(expf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
+
+static const float GELU_COEF_A     = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
+static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
+inline static float ggml_gelu_f32(float x) {
+    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_table_gelu_f16[i16[i]];
+    }
+}
+
+#ifdef GGML_GELU_FP16
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        if (x[i] <= -10.0f) {
+            y[i] = 0.0f;
+        } else if (x[i] >= 10.0f) {
+            y[i] = x[i];
+        } else {
+            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+            memcpy(&t, &fp16, sizeof(uint16_t));
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
+        }
+    }
+}
+#else
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_f32(x[i]);
+    }
+}
+#endif
+
+inline static float ggml_gelu_quick_f32(float x) {
+    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
+}
+
+//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+//    const uint16_t * i16 = (const uint16_t *) x;
+//    for (int i = 0; i < n; ++i) {
+//        y[i] = ggml_table_gelu_quick_f16[i16[i]];
+//    }
+//}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]);
+    }
+}
+#endif
+
+inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v))));
+    }
+}
+
+// Sigmoid Linear Unit (SiLU) function
+inline static float ggml_silu_f32(float x) {
+    return x/(1.0f + expf(-x));
+}
+inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
+    float v = GGML_FP16_TO_FP32(x);
+    return GGML_FP32_TO_FP16(v/(1.0f + expf(-v)));
+}
+
+#if __FINITE_MATH_ONLY__
+#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
+#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
+#endif
+
+#if defined(__ARM_NEON) && defined(__aarch64__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static float32x4_t ggml_v_expf(float32x4_t x) {
+    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
+    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
+    const float32x4_t n = vsubq_f32(z, r);
+    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
+                                    vdupq_n_f32(0x1.7f7d1cp-20f));
+    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
+    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
+    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
+    const float32x4_t u = vmulq_f32(b, b);
+    const float32x4_t j = vfmaq_f32(
+        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
+        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
+                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
+    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
+        return vfmaq_f32(k, j, k);
+    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
+    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
+    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
+    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
+                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static float32x4_t ggml_v_silu(float32x4_t x) {
+    const float32x4_t one = vdupq_n_f32(1.0f);
+    const float32x4_t zero = vdupq_n_f32(0.0f);
+    const float32x4_t neg_x = vsubq_f32(zero, x);
+    const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
+    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
+    return vdivq_f32(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX512F__) && defined(__AVX512DQ__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m512 ggml_v_expf(__m512 x) {
+  const __m512 r = _mm512_set1_ps(0x1.8p23f);
+  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
+  const __m512 n = _mm512_sub_ps(z, r);
+  const __m512 b =
+      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
+  const __mmask16 d =
+      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
+  const __m512 u = _mm512_mul_ps(b, b);
+  const __m512 j = _mm512_fmadd_ps(
+      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+                                      _mm512_set1_ps(0x1.573e2ep-5f)),
+                      u,
+                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
+      u,
+      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+  const __m512 res = _mm512_scalef_ps(j, n);
+  if (_mm512_kortestz(d, d))
+    return res;
+  const __m512 zero = _mm512_setzero_ps();
+  const __m512 alt = _mm512_mask_blend_ps(
+      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+  return _mm512_mask_blend_ps(d, res, alt);
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m512 ggml_v_silu(__m512 x) {
+    const __m512 one = _mm512_set1_ps(1);
+    const __m512 zero = _mm512_setzero_ps();
+    const __m512 neg_x = _mm512_sub_ps(zero, x);
+    const __m512 exp_neg_x = ggml_v_expf(neg_x);
+    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
+    return _mm512_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX2__) && defined(__FMA__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m256 ggml_v_expf(__m256 x) {
+  const __m256 r = _mm256_set1_ps(0x1.8p23f);
+  const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
+  const __m256 n = _mm256_sub_ps(z, r);
+  const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
+                                    _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
+  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
+  const __m256 k = _mm256_castsi256_ps(
+      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
+  const __m256i c = _mm256_castps_si256(
+      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                    _mm256_set1_ps(126), _CMP_GT_OQ));
+  const __m256 u = _mm256_mul_ps(b, b);
+  const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
+                                                                   _mm256_set1_ps(0x1.573e2ep-5f)), u,
+                                                   _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
+                                                                   _mm256_set1_ps(0x1.fffdb6p-2f))),
+                                   u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
+  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
+    return _mm256_fmadd_ps(j, k, k);
+  const __m256i g = _mm256_and_si256(
+      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
+      _mm256_set1_epi32(0x82000000u));
+  const __m256 s1 =
+      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
+  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
+  const __m256i d = _mm256_castps_si256(
+      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                    _mm256_set1_ps(192), _CMP_GT_OQ));
+  return _mm256_or_ps(
+      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
+      _mm256_andnot_ps(
+          _mm256_castsi256_ps(d),
+          _mm256_or_ps(
+              _mm256_and_ps(_mm256_castsi256_ps(c),
+                            _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
+              _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m256 ggml_v_silu(__m256 x) {
+    const __m256 one = _mm256_set1_ps(1);
+    const __m256 zero = _mm256_setzero_ps();
+    const __m256 neg_x = _mm256_sub_ps(zero, x);
+    const __m256 exp_neg_x = ggml_v_expf(neg_x);
+    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
+    return _mm256_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
+
+#if defined(__FMA__)
+#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
+#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
+#else
+#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
+#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
+#endif
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m128 ggml_v_expf(__m128 x) {
+    const __m128 r = _mm_set1_ps(0x1.8p23f);
+    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
+    const __m128 n = _mm_sub_ps(z, r);
+    const __m128 b =
+        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
+    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
+    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
+    const __m128i c =
+        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
+    const __m128 u = _mm_mul_ps(b, b);
+    const __m128 j =
+        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
+                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
+                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
+    if (!_mm_movemask_epi8(c))
+        return MADD128(j, k, k);
+    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
+                                    _mm_set1_epi32(0x82000000u));
+    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
+    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
+    const __m128i d =
+        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
+    return _mm_or_ps(
+        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
+        _mm_andnot_ps(_mm_castsi128_ps(d),
+                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
+                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m128 ggml_v_silu(__m128 x) {
+    const __m128 one = _mm_set1_ps(1);
+    const __m128 zero = _mm_setzero_ps();
+    const __m128 neg_x = _mm_sub_ps(zero, x);
+    const __m128 exp_neg_x = ggml_v_expf(neg_x);
+    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
+    return _mm_div_ps(x, one_plus_exp_neg_x);
+}
+
+#endif // __ARM_NEON / __AVX2__ / __SSE2__
+
+inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_silu_f16(x[i]);
+    }
+}
+
+inline static float ggml_silu_backward_f32(float x, float dy) {
+    const float s = 1.0f/(1.0f + expf(-x));
+    return dy*s*(1.0f + x*(1.0f - s));
+}
+
+inline static ggml_fp16_t ggml_silu_backward_f16(ggml_fp16_t x, ggml_fp16_t dy) {
+    const float v = GGML_FP16_TO_FP32(x);
+    const float s = 1.0f/(1.0f + expf(-v));
+    return GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(dy)*s*(1.0f + v*(1.0f - s)));
+}
+
+inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
+    for (int i = 0; i < n; ++i) {
+        dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
+    }
+}
+
+inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) {
+    for (int i = 0; i < n; ++i) {
+        dx[i] = ggml_silu_backward_f16(x[i], dy[i]);
+    }
+}
+
+inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += (ggml_float)x[i];
+    }
+    *s = (float)sum;
+#else
+    vDSP_sve(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += (ggml_float)x[i];
+    }
+    *s = sum;
+}
+
+inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
+    float sum = 0.0f;
+    for (int i = 0; i < n; ++i) {
+        sum += GGML_FP16_TO_FP32(x[i]);
+    }
+    *s = sum;
+}
+
+inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
+    float sum = 0.0f;
+    for (int i = 0; i < n; ++i) {
+        sum += GGML_BF16_TO_FP32(x[i]);
+    }
+    *s = sum;
+}
+
+inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+    float max = -INFINITY;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+    }
+    *s = max;
+#else
+    vDSP_maxv(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
+    ggml_vec_norm_f32(n, s, x);
+    *s = 1.f/(*s);
+}
+
+inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
+    float max = -INFINITY;
+    int idx = 0;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+        if (max == x[i]) { idx = i; }
+    }
+    *s = idx;
+}
+
+#ifdef __cplusplus
+}
+#endif
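ggml_gelu_f32 uses the common tanh approximation of GELU; when GGML_GELU_FP16 is defined, ggml_vec_gelu_f32 instead reads results from the 2^16-entry ggml_table_gelu_f16, indexed by the raw bit pattern of the half-precision input, with inputs outside [-10, 10] handled directly. The sketch below compares the tanh approximation against the exact erf-based definition; it is for illustration only and does not use the table.

#include <cmath>
#include <cstdio>

// Tanh approximation used by ggml_gelu_f32 above.
static float gelu_tanh(float x) {
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    const float GELU_COEF_A    = 0.044715f;
    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

// Exact GELU, x * Phi(x), written with the error function for comparison.
static float gelu_exact(float x) {
    return 0.5f*x*(1.0f + erff(x/sqrtf(2.0f)));
}

int main() {
    const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
    for (float x : xs) {
        std::printf("x=% .1f  approx=% .6f  exact=% .6f\n", x, gelu_tanh(x), gelu_exact(x));
    }
    return 0;
}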
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
index 96bd5a0be..8623214c7 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
@@ -102,6 +102,15 @@ if (CUDAToolkit_FOUND)
 
     set(CUDA_FLAGS -use_fast_math)
 
+    if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
+        # Options are:
+        # - none (not recommended)
+        # - speed (nvcc's default)
+        # - balance
+        # - size
+        list(APPEND CUDA_FLAGS -compress-mode=${GGML_CUDA_COMPRESSION_MODE})
+    endif()
+
     if (GGML_FATAL_WARNINGS)
         list(APPEND CUDA_FLAGS -Werror all-warnings)
     endif()
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
index ce4b9cfb5..e1fbf0e13 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
@@ -294,11 +294,13 @@ static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
 
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
 
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
         op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
         op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
         op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/clamp.cu b/ml/backend/ggml/ggml/src/ggml-cuda/clamp.cu
index 8009a3e3d..fe415e7f7 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/clamp.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/clamp.cu
@@ -1,34 +1,45 @@
 #include "clamp.cuh"
 
-static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
+static __device__ __forceinline__ float op_clamp(float x, float min, float max) {
+    return fminf(fmaxf(x, min), max);
+}
+
+template <class T>
+static __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
-    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+    dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max);
 }
 
-static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
+template <class T>
+static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
-    clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+    op_clamp_kernel<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
 }
 
 
 void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
     float min;
     float max;
     memcpy(&min, dst->op_params, sizeof(float));
     memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 
-    clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        clamp_cuda((const half *)src0_d, (half *)dst_d, (half)min, (half)max, ggml_nelements(src0), stream);
+    } else {
+        clamp_cuda((const float *)src0_d, (float *)dst_d, (float)min, (float)max, ggml_nelements(src0), stream);
+    }
 }
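ggml_cuda_op_clamp now templates the kernel on the storage type and does the comparison arithmetic in float before casting back, which is what lets one code path serve both F32 and F16 tensors. Below is a host-side sketch of that pattern, with float and double standing in for the float/half pair; the names are illustrative and not part of ggml.

#include <cmath>
#include <cstdio>
#include <vector>

// Compute in float, store in T, mirroring op_clamp_kernel above.
template <class T>
static void clamp_host(const T * x, T * dst, float min, float max, int k) {
    for (int i = 0; i < k; ++i) {
        dst[i] = (T) std::fmin(std::fmax((float) x[i], min), max);
    }
}

int main() {
    std::vector<float>  xf = {-3.0f, 0.25f, 7.0f}, yf(xf.size());
    std::vector<double> xd = {-3.0, 0.25, 7.0},    yd(xd.size());

    clamp_host(xf.data(), yf.data(), -1.0f, 1.0f, (int) xf.size());
    clamp_host(xd.data(), yd.data(), -1.0f, 1.0f, (int) xd.size());

    for (float v : yf)  std::printf("%g ", v);    // -1 0.25 1
    for (double v : yd) std::printf("%g ", v);    // -1 0.25 1
    std::printf("\n");
    return 0;
}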
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
index b24593fc2..a718b6a12 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
@@ -41,15 +41,18 @@
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
 
-#define GGML_CUDA_CC_PASCAL       600
-#define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define GGML_CUDA_CC_VOLTA        700
-#define GGML_CUDA_CC_TURING       750
-#define GGML_CUDA_CC_AMPERE       800
-#define GGML_CUDA_CC_ADA_LOVELACE 890
-#define GGML_CUDA_CC_OFFSET_AMD   0x1000000
+#define GGML_CUDA_CC_PASCAL          600
+#define GGML_CUDA_CC_DP4A            610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA           700
+#define GGML_CUDA_CC_TURING          750
+#define GGML_CUDA_CC_AMPERE          800
+#define GGML_CUDA_CC_ADA_LOVELACE    890
+#define GGML_CUDA_CC_OFFSET_AMD      0x1000000
+#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
+#define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
 
-// GCN/CNDA, wave size is 64
+// AMD
+// GCN/CDNA, wave size is 64
 #define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue
 #define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 0x906)  // MI50/Radeon VII, minimum for dp4a
@@ -57,12 +60,13 @@
 #define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910)  // MI210, minimum acc register renameing
 #define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x942)  // MI300
 
-// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
+// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
 #define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
 #define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
 #define GGML_CUDA_CC_RDNA4      (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
 
+#define GGML_CUDA_CC_IS_AMD(cc)   (cc >= GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_RDNA(cc)  (cc >= GGML_CUDA_CC_RDNA1)
 #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
@@ -71,8 +75,17 @@
 #define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
 #define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
-#define GGML_CUDA_CC_QY1        210
-#define GGML_CUDA_CC_QY2        220
+// Moore Threads
+#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
+
+#define GGML_CUDA_CC_QY1  (GGML_MUSA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2  (GGML_MUSA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG   (GGML_MUSA_CC_OFFSET_MTHREADS + 0x310) // TBD
+
+#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
+#define GGML_CUDA_CC_IS_QY1(cc)      (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
+#define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NEXT)
+#define GGML_CUDA_CC_IS_NG(cc)       (cc >= GGML_CUDA_CC_NG)
 
 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {
@@ -198,6 +211,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -206,36 +223,42 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
 }
 
 static bool fast_fp16_available(const int cc) {
-    return fp16_available(cc) && cc != 610;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
 static bool fp16_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
+    return false;
+#else
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool new_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
 static bool cp_async_available(const int cc) {
@@ -265,6 +288,10 @@ static __device__ void no_device_code(
     __trap();
 
     GGML_UNUSED(no_device_code); // suppress unused function warning
+
+#if defined(GGML_USE_MUSA)
+    __builtin_unreachable();
+#endif // defined(GGML_USE_MUSA)
 }
 
 #ifdef __CUDA_ARCH__
@@ -386,11 +413,11 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 
 static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
+#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
 #elif defined(RDNA3) || defined(RDNA4)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
-#elif defined(__gfx1010__) || defined(__gfx900__)
+#elif defined(RDNA1) || defined(__gfx900__)
     int tmp1;
     int tmp2;
     asm("\n \
@@ -669,7 +696,7 @@ struct ggml_tensor_extra_gpu {
 };
 
 
-#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
+#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
 #define USE_CUDA_GRAPH
 #endif
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu b/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu
index aafbaf803..e9ffd274b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu
@@ -38,7 +38,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
 
-    if (blockIdx.y < ne01) { // src0
+    if (blockIdx.y < (unsigned)ne01) { // src0
         int offset_src =
             nidx +
             blockIdx.y * ne0 +
@@ -64,7 +64,7 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
 
-    if (blockIdx.z < ne02) { // src0
+    if (blockIdx.z < (unsigned)ne02) { // src0
         int offset_src =
             nidx +
             blockIdx.y * ne0 +
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/conv-transpose-1d.cu b/ml/backend/ggml/ggml/src/ggml-cuda/conv-transpose-1d.cu
index b1e94d6f7..fe4caf674 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/conv-transpose-1d.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/conv-transpose-1d.cu
@@ -34,6 +34,10 @@ static  __global__ void conv_transpose_1d_kernel(
         }
     }
     dst[global_index] = accumulator;
+    GGML_UNUSED(p0); GGML_UNUSED(d0); GGML_UNUSED(src0_ne3);
+    GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3);
+    GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1);
+    GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2);
 }
 
 static void conv_transpose_1d_f32_f32_cuda(
@@ -75,8 +79,6 @@ void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor
     const int p0 = 0;//opts[3];
     const int d0 = 1;//opts[4];
 
-    const int64_t kernel_size = ggml_nelements(src0);
-    const int64_t input_size = ggml_nelements(src1);
     const int64_t output_size = ggml_nelements(dst);
 
     conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
index 795b720d6..a224ec0e1 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
@@ -577,9 +577,9 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
         return;
     }
 
-    const src_t * x = (src_t *) vx;
+    const src_t * x = (const src_t *) vx;
 
-    y[i] = x[i];
+    y[i] = float(x[i]);
 }
 
 template <typename src_t, typename dst_t>
@@ -588,6 +588,17 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
     convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
+to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float>;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half>;
+        default:
+            return nullptr;
+    }
+}
+
 to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -633,6 +644,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16>;
         default:
             return nullptr;
     }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cuh
index 5394be9f1..411a13cf1 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cuh
@@ -7,7 +7,10 @@ using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, in
 
 typedef to_t_cuda_t<float> to_fp32_cuda_t;
 typedef to_t_cuda_t<half> to_fp16_cuda_t;
+typedef to_t_cuda_t<nv_bfloat16> to_bf16_cuda_t;
 
 to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
 
+to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
+
 to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
index cca2bee0b..8396df28c 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
@@ -10,6 +10,13 @@ static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
+
+    *dsti = *xi;
+}
+
 static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     half * dsti = (half *) cdsti;
@@ -359,6 +366,16 @@ static void ggml_cpy_f32_f32_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_f32_bf16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -545,9 +562,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -593,6 +612,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
index 7b9566fb4..56121705b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
@@ -52,7 +52,7 @@ typedef half (*vec_dot_KQ_f16_t)(
 typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
 
-template
+template
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
@@ -62,7 +62,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
@@ -70,7 +70,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const int shift = k_KQ & (QI8_1/2);
 
         const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -78,21 +78,21 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
-            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }
 
     return sum;
 }
 
-template
+template
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
@@ -102,7 +102,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
@@ -110,7 +110,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const int shift = k_KQ & (QI8_1/2);
 
         const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -118,7 +118,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
-            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
         } else
@@ -126,8 +126,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
 
             sum += (T) (sumid4d8 + m4s8scaled);
         }
@@ -136,7 +136,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }
 
-template
+template
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
@@ -146,7 +146,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
@@ -161,7 +161,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28
 
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -169,21 +169,21 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
-            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }
 
     return sum;
 }
 
-template
+template
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
@@ -193,7 +193,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
@@ -208,7 +208,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28
 
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -216,7 +216,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
-            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
         } else
@@ -224,8 +224,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
 
             sum += (T) (sumid5d8 + m5s8scaled);
         }
@@ -234,7 +234,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }
 
-template 
+template 
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
@@ -244,7 +244,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib  = k_KQ / QI8_0;
@@ -255,19 +255,19 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         T Q_d;
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
-            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
         } else {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
-            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+            Q_d = Q_ds[k_KQ_0/warp_size].x;
         }
 
-        sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+        sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
     }
 
     return sum;
 }
 
-template 
+template 
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
 
@@ -282,11 +282,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         half2 sum2 = make_half2(0.0f, 0.0f);
 
 #pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
             const int k_KQ = k_KQ_0 + threadIdx.x;
 
             const half2 K_ik = K_h2[k_KQ];
-            sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+            sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
         }
 
         return __low2half(sum2) + __high2half(sum2);
@@ -298,12 +298,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     float sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const half2 K_ik = K_h2[k_KQ];
-        sum +=  __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
-        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+        sum +=  __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
     }
 
     return sum;
@@ -315,14 +315,14 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
 
     float vals[sizeof(int)] = {0.0f};
 #pragma unroll
-    for (int l = 0; l < sizeof(int); ++l) {
+    for (int l = 0; l < int(sizeof(int)); ++l) {
         vals[l] = scale * x[4*threadIdx.x + l];
     }
 
     float amax = fabsf(vals[0]);
     float sum  = vals[0];
 #pragma unroll
-    for (int l = 1; l < sizeof(int); ++l) {
+    for (int l = 1; l < int(sizeof(int)); ++l) {
         amax = fmaxf(amax, fabsf(vals[l]));
         sum += vals[l];
     }
@@ -338,7 +338,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
 
     if (d != 0.0f) {
 #pragma unroll
-        for (int l = 0; l < sizeof(int); ++l) {
+        for (int l = 0; l < int(sizeof(int)); ++l) {
             q8[l] = roundf(vals[l] / d);
         }
     }
@@ -474,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }
 
-template 
+template 
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 :
         nullptr;
 }
 
-template 
+template 
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 :
         nullptr;
 }
 
@@ -606,50 +606,50 @@ static __global__ void flash_attn_stream_k_fixup(
     *dst = dst_val / rowsum;
 }
 
-template // D == head size
+template // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
-        float * __restrict__ dst) {
-    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
-    dst       +=                 D * gridDim.y*blockIdx.x;
+        float * __restrict__ dst,
+        const int parallel_blocks) {
+    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
+    dst       +=                 D * gridDim.z*blockIdx.x;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
-    __shared__ float2 meta[parallel_blocks];
+    extern __shared__ float2 meta[];
     if (tid < 2*parallel_blocks) {
-        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
     }
 
     __syncthreads();
 
     float kqmax = meta[0].x;
-#pragma unroll
     for (int l = 1; l < parallel_blocks; ++l) {
         kqmax = max(kqmax, meta[l].x);
     }
 
     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
-#pragma unroll
     for (int l = 0; l < parallel_blocks; ++l) {
         const float diff = meta[l].x - kqmax;
-        const float KQ_max_scale = expf(diff);
+        float KQ_max_scale = expf(diff);
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
 }
 
+[[noreturn]]
 static void on_no_fattn_vec_case(const int D) {
     if (D == 64) {
         fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
@@ -671,11 +671,10 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-// parallel_blocks == 0 is stream-k decomposition
-template 
+template 
 void launch_fattn(
-    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+    const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -747,12 +746,14 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
+    int parallel_blocks = 1;
+
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
-    if (parallel_blocks == 0) {
+    if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
         const int max_blocks = 2*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
@@ -768,9 +769,43 @@ void launch_fattn(
 
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
-        blocks_num.x = parallel_blocks*ntiles_x;
-        blocks_num.y = Q->ne[2];
-        blocks_num.z = Q->ne[3];
+        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
+        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+
+        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
+        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
+        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
+
+        // parallel_blocks must not be larger than what the tensor size allows:
+        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+
+        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+        // Test whether parallel_blocks can be set to a higher value for better efficiency.
+        const int blocks_per_wave = nsm * max_blocks_per_sm;
+        int nwaves_best = 0;
+        int efficiency_percent_best = 0;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_total * parallel_blocks_test;
+            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+            if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
+                break;
+            }
+
+            if (efficiency_percent > efficiency_percent_best) {
+                nwaves_best = nwaves;
+                efficiency_percent_best = efficiency_percent;
+                parallel_blocks = parallel_blocks_test;
+            }
+        }
+
+        blocks_num.x = ntiles_x;
+        blocks_num.y = parallel_blocks;
+        blocks_num.z = Q->ne[2]*Q->ne[3];
 
         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -796,12 +831,13 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    GGML_ASSERT(block_dim.x % warp_size == 0);
     fattn_kernel<<>>(
         (const char *) Q->data,
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
+        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -813,7 +849,7 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if constexpr (parallel_blocks == 0) {
+    if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
@@ -822,13 +858,14 @@ void launch_fattn(
                 <<>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
-    } else if constexpr (parallel_blocks > 1) {
+    } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(D, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results
-            <<>>
-            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results
+            <<>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
     CUDA_CHECK(cudaGetLastError());
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 718ee5402..04804a15c 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -406,6 +406,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #endif // CP_ASYNC_AVAILABLE
 
 #else
+    GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
+    GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
+    GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
+    GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
+    GGML_UNUSED(kb0);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -797,6 +806,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         __syncthreads();
     }
 #else
+    GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
+    GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
+    GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
+    GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask);
+    GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -931,6 +946,16 @@ static __global__ void flash_attn_ext_f16(
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
@@ -970,7 +995,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
         fattn_kernel = flash_attn_ext_f16;
     }
 
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
+    launch_fattn
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
 }
 
 
@@ -984,38 +1010,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
     extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,   8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,   8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,   8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,   8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,   8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,   8)
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  16)
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  32)
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  64)
 
 // Kernels with ncols == 128 are only 4% faster due to register pressure.
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu
index ef3569fab..e0039e175 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template // D == head size
+template // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half   * maskh = (const half   *)  mask + ne11*ic0;
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
@@ -271,24 +269,36 @@ static __global__ void flash_attn_tile_ext_f16(
             const int i0 = i00 + 2*threadIdx.x;
 
             half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
-            if (parallel_blocks == 1) {
+            if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }
-            const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] =  __low2float(dst_val);
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] =  __low2float(dst_val);
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
         }
 
-        if (parallel_blocks != 1 && threadIdx.x == 0) {
-            dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+        if (gridDim.y != 1 && threadIdx.x == 0) {
+            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
-   NO_DEVICE_CODE;
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
-template 
+template 
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
@@ -296,15 +306,17 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    D             = 64;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16;
-            launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16;
+            launch_fattn
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16;
-            launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16;
+            launch_fattn
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,37 +336,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128(ctx, dst);
+            launch_fattn_tile_f16_64_128(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128(ctx, dst);
+            launch_fattn_tile_f16_64_128(ctx, dst);
         }
         return;
     }
 
     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f16_64_128(ctx, dst);
+        launch_fattn_tile_f16_64_128(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f16_64_128(ctx, dst);
+        launch_fattn_tile_f16_64_128(ctx, dst);
     }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu
index 04b69c83b..fcb6f848f 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
-template // D == head size
+template // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -52,24 +52,35 @@ static __global__ void flash_attn_tile_ext_f32(
     return;
 #endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
+        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+        GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+        GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+        GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+        GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+        GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
         NO_DEVICE_CODE;
         return;
     }
 
     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half   * maskh = (const half   *)  mask + ne11*ic0;
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -103,8 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
     __syncthreads();
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new[ncols/nwarps];
@@ -269,25 +279,37 @@ static __global__ void flash_attn_tile_ext_f32(
             const int i0 = i00 + 2*threadIdx.x;
 
             float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
-            if (parallel_blocks == 1) {
+            if (gridDim.y == 1) {
                 dst_val.x /= kqsum_j;
                 dst_val.y /= kqsum_j;
             }
-            const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
+            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
         }
 
-        if (parallel_blocks != 1 && threadIdx.x == 0) {
-            dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+        if (gridDim.y != 1 && threadIdx.x == 0) {
+            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
 
-template 
+template 
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
@@ -295,15 +317,17 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    D             = 64;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32;
-            launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32;
+            launch_fattn
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32;
-            launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32;
+            launch_fattn
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -320,37 +344,22 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f32_64_128(ctx, dst);
+            launch_fattn_tile_f32_64_128(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f32_64_128(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f32_64_128(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f32_64_128(ctx, dst);
+            launch_fattn_tile_f32_64_128(ctx, dst);
         }
         return;
     }
 
     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f32_64_128(ctx, dst);
+        launch_fattn_tile_f32_64_128(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f32_64_128(ctx, dst);
+        launch_fattn_tile_f32_64_128(ctx, dst);
     }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index b7686c1ec..e17d2d0e4 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template // D == head size
+template // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -55,17 +55,16 @@ static __global__ void flash_attn_vec_ext_f16(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio);
+    Q += nb02* blockIdx.z              + nb01*ic0;
+    K += nb12*(blockIdx.z / gqa_ratio);
+    V += nb22*(blockIdx.z / gqa_ratio);
 
     const half * maskh = (const half   *)  mask + ne11*ic0;
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -172,8 +171,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -283,29 +281,41 @@ static __global__ void flash_attn_vec_ext_f16(
         kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
 
         half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
-   NO_DEVICE_CODE;
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
 
 template 
@@ -325,65 +335,48 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 2;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
+    constexpr int cols_per_block = 8;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index c1d2dd8d1..d42ddca49 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -45,6 +45,18 @@ static __global__ void flash_attn_vec_ext_f32(
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
+        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+        GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+        GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+        GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+        GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+        GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
         NO_DEVICE_CODE;
         return;
     }
@@ -55,16 +67,15 @@ static __global__ void flash_attn_vec_ext_f32(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
+    Q += nb02* blockIdx.z              + nb01*ic0;
+    K += nb12*(blockIdx.z / gqa_ratio);
+    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
     const half * maskh = (const half   *)  mask + ne11*ic0;
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -115,7 +126,7 @@ static __global__ void flash_attn_vec_ext_f32(
             // Set memory to zero if out of bounds:
             if (ncols > 2 && ic0 + j >= ne01) {
 #pragma unroll
-                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                     const int i = i0 + threadIdx.x;
 
                     tmp_q_i32[i] = 0;
@@ -128,7 +139,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
             const float * Q_f = (const float *) (Q + j*nb01);
 #pragma unroll
-            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                 quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
             }
         }
@@ -141,7 +152,7 @@ static __global__ void flash_attn_vec_ext_f32(
             float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
 
 #pragma unroll
-            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
                 Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
@@ -167,8 +178,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new_arr[ncols];
@@ -268,29 +278,39 @@ static __global__ void flash_attn_vec_ext_f32(
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
         float dst_val = VKQ[j_VKQ];
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
 
 template 
@@ -307,65 +327,48 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 2;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
+    constexpr int cols_per_block = 8;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 8828652fb..bc21b27a0 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -7,14 +7,19 @@
 #include "fattn-wmma-f16.cuh"
 
 #ifdef FP16_MMA_AVAILABLE
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+namespace wmma = nvcuda::wmma;
+#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -60,19 +65,20 @@ static __global__ void flash_attn_ext_f16(
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
-    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
 
     static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ?  8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
 
     constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -84,16 +90,16 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.z              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.z / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
     const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
 
@@ -132,9 +138,9 @@ static __global__ void flash_attn_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+            if (i0 + warp_size > D/2 && i >= D/2) {
                 break;
             }
             VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
@@ -146,9 +152,9 @@ static __global__ void flash_attn_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + warp_size > D && i >= D) {
                 break;
             }
             KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
@@ -162,34 +168,34 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }
 
     __syncthreads();
 
     // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
         // Calculate tile of KQ:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+                wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
             }
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                 }
             }
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
             }
         }
 
@@ -202,27 +208,27 @@ static __global__ void flash_attn_ext_f16(
             const int j = j0 + threadIdx.y;
 
             if (std::is_same<KQ_acc_t, float>::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+                    KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
 
                     if (use_logit_softcap) {
-                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                        KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
                     }
                 }
 
                 float KQ_max_new = KQ_max_f[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
-                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                    KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
                 }
-                KQ_max_new = warp_reduce_max(KQ_max_new);
+                KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
 
                 const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
                 KQ_max_scale_f[j0/nwarps] = expf(diff);
@@ -233,48 +239,48 @@ static __global__ void flash_attn_ext_f16(
 
                 float KQ_rowsum_add = 0.0f;
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
-                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/warp_size] = expf(diff);
                     if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                        KQ_f_tmp[k0/warp_size] = 0.0f;
                     }
-                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
-                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                    KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
                 }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
 
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
             } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+                    KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
 
                     if (use_logit_softcap) {
                         // There is no dedicated tangens hyperbolicus function for half2.
-                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
-                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
-                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+                        KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
+                        KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
+                                               /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
 
-                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                        KQ2_tmp[k0/warp_size] *= logit_softcap_2;
                     }
                 }
 
                 half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
-                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                    KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
                 }
-                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
                 const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
                 KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
                 const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
@@ -283,17 +289,17 @@ static __global__ void flash_attn_ext_f16(
 
                 half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
-                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/warp_size] = h2exp(diff);
                     const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
-                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
-                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                    *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/warp_size];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
                 }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
 
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
@@ -308,7 +314,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
@@ -320,7 +326,7 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
             }
 
 #pragma unroll
@@ -328,10 +334,10 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -343,10 +349,10 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    D_padded, wmma::mem_col_major);
             }
         }
 
@@ -364,9 +370,9 @@ static __global__ void flash_attn_ext_f16(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
-                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                if (i0 + warp_size > D/2 && i >= D/2) {
                     break;
                 }
 
@@ -388,7 +394,7 @@ static __global__ void flash_attn_ext_f16(
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
 
         float KQ_rowsum_j;
         if (std::is_same<KQ_acc_t, float>::value) {
@@ -398,19 +404,19 @@ static __global__ void flash_attn_ext_f16(
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + warp_size > D && i >= D) {
                 break;
             }
             float dst_val = VKQ[j_VKQ*D_padded + i];
-            if (parallel_blocks == 1) {
+            if (gridDim.y == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+            dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
         }
 
-        if (parallel_blocks == 1 || threadIdx.x != 0) {
+        if (gridDim.y == 1 || threadIdx.x != 0) {
             continue;
         }
 
@@ -421,11 +427,21 @@ static __global__ void flash_attn_ext_f16(
             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
         dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+        dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
     }
 #else
-   NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
 
 constexpr int get_max_power_of_2(int x) {
@@ -455,59 +471,26 @@ static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
 template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
 
     constexpr int nwarps = 4;
 
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
-    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (4*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
-        return;
-    }
-    if (2*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
-        return;
-    }
-    constexpr int parallel_blocks = 1;
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
         fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     } else {
         constexpr bool use_logit_softcap = true;
         fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -515,6 +498,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * Q   = dst->src[0];
 
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
 
     if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@@ -571,7 +555,8 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         return;
     }
 
-    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+    if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
             case 64:
@@ -592,6 +577,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
index b1becccb4..7a2d1e453 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
@@ -250,10 +250,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
+    if (GGML_CUDA_CC_IS_AMD(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+        if (fp16_mma_available(cc)) {
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+            return;
+        }
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
+        // On AMD the tile kernels perform poorly, use the vec kernel instead:
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
@@ -273,13 +281,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
-            if (Q->ne[1] <= 8) {
+            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             } else {
                 ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
             }
         } else {
-            if (Q->ne[1] <= 8) {
+            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                 ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
             } else {
                 ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
@@ -288,21 +296,21 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    const int gqa_ratio = Q->ne[2] / K->ne[2];
-    const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
-        K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
+    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
+    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
+    const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
+    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-            return;
-        } else if(Q->ne[0] <= 128) {
+        } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-            return;
         }
+        return;
     }
 
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (cc == GGML_CUDA_CC_VOLTA) {
+    if (fp16_mma_available(cc) && !new_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
index c7a957c82..a44788db2 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -31,12 +31,14 @@
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
 #include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/ssm-conv.cuh"
+#include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
 #include "ggml.h"
 
@@ -262,9 +264,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
                       id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
                       device_vmm ? "yes" : "no", prop.warpSize);
 #elif defined(GGML_USE_MUSA)
-        // TODO: refine the .cc to reflect MUSA's actual CC capabilities
+        // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
+        info.devices[id].warp_size = 32;
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
-        info.devices[id].cc = 100*prop.major + 10*prop.minor;
+        info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
+        info.devices[id].cc += prop.minor * 0x10;
         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                         id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #else
@@ -541,12 +545,12 @@ static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
     return ctx->dev_ptr;
 }
 
-static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
     if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        return;
+        return GGML_STATUS_SUCCESS;
     }
 
     if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
@@ -559,6 +563,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
             CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
         }
     }
+    return GGML_STATUS_SUCCESS;
 }
 
 static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
@@ -794,7 +799,7 @@ static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buff
     GGML_UNUSED(buffer);
 }
 
-static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
 
     ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
@@ -840,6 +845,7 @@ static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buf
         }
     }
     tensor->extra = extra;
+    return GGML_STATUS_SUCCESS;
 }
 
 static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
@@ -1187,11 +1193,39 @@ static void ggml_cuda_op_mul_mat_cublas(
     // ldc == nrows of the matrix that cuBLAS writes into
     int64_t ldc = id == ctx.device ? ne0 : row_diff;
 
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
+    const int cc = ggml_cuda_info().devices[id].cc;
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
+    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+        ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
+        if (src1->type != GGML_TYPE_BF16) {
+            const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
+            GGML_ASSERT(to_bf16_cuda != nullptr);
+            size_t ne = src1_ncols*ne10;
+            src1_as_bf16.alloc(ne);
+            to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);
+        }
+        const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
+        const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i;
+        ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
+
+        const float alpha_f32 = 1.0f;
+        const float beta_f32  = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha_f32,  src0_ptr,       CUDA_R_16BF, ne00,
+                                 src1_ptr,       CUDA_R_16BF, ne10,
+                    &beta_f32,   dst_bf16.get(), CUDA_R_16BF, ldc,
+                    CUBLAS_COMPUTE_32F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+        to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1215,7 +1249,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
-        if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
+        if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
             const float alpha = 1.0f;
             const float beta = 0.0f;
             CUBLAS_CHECK(
@@ -1758,7 +1792,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         beta  = &beta_f32;
     }
 
-    if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
+    int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
         alpha = &alpha_f32;
         beta  = &beta_f32;
@@ -1835,7 +1871,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     }
 #endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
     }
@@ -2148,6 +2184,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(dst)) {
+                case GGML_UNARY_OP_ABS:
+                    ggml_cuda_op_abs(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SGN:
+                    ggml_cuda_op_sgn(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_NEG:
                     ggml_cuda_op_neg(ctx, dst);
                     break;
@@ -2191,6 +2233,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GROUP_NORM:
             ggml_cuda_op_group_norm(ctx, dst);
             break;
+        case GGML_OP_L2_NORM:
+            ggml_cuda_op_l2_norm(ctx, dst);
+            break;
         case GGML_OP_CONCAT:
             ggml_cuda_op_concat(ctx, dst);
             break;
@@ -2248,6 +2293,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;
+        case GGML_OP_LOG:
+            ggml_cuda_op_log(ctx, dst);
+            break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -2284,6 +2332,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_SSM_CONV:
+            ggml_cuda_op_ssm_conv(ctx, dst);
+            break;
+        case GGML_OP_SSM_SCAN:
+            ggml_cuda_op_ssm_scan(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
@@ -2301,6 +2355,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_cuda_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_cuda_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -2568,7 +2625,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
         for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
             if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
                 char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                *(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
                 CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
             }
         }
@@ -2607,13 +2664,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
 
 static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 
+#if CUDART_VERSION >= 12000
     cudaGraphExecUpdateResultInfo result_info;
-#ifdef __HIP_PLATFORM_AMD__
-    hipGraphNode_t errorNode;
-    hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
-#else
     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-#endif
+#else
+    cudaGraphNode_t errorNode;
+    cudaGraphExecUpdateResult result_info;
+    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+#endif // CUDART_VERSION >= 12000
+
     if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
         GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
@@ -2968,6 +3027,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_NEG:
                 case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:
@@ -3071,6 +3132,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }
@@ -3150,10 +3214,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             } break;
         case GGML_OP_SILU_BACK:
-            return ggml_is_contiguous(op->src[0]);
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return true;
         case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
@@ -3174,6 +3239,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LOG:
+        case GGML_OP_SSM_SCAN:
+        case GGML_OP_SSM_CONV:
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
@@ -3201,6 +3269,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_PAD:
         case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
@@ -3208,11 +3277,22 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif // FLASH_ATTN_AVAILABLE
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
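
Note on the cudaGraphExecUpdate hunk above: CUDA 12 replaced the two out-parameters (failing node, update-result enum) with a single cudaGraphExecUpdateResultInfo struct, which is why the call is now guarded on CUDART_VERSION. A minimal sketch of the guarded call, assuming only the CUDA runtime header (the helper name is illustrative, not part of the patch):

#include <cuda_runtime.h>

// Sketch: pick the cudaGraphExecUpdate overload that matches the runtime version.
static cudaError_t graph_exec_update(cudaGraphExec_t instance, cudaGraph_t graph) {
#if CUDART_VERSION >= 12000
    cudaGraphExecUpdateResultInfo result_info;   // CUDA >= 12.0: single result-info struct
    return cudaGraphExecUpdate(instance, graph, &result_info);
#else
    cudaGraphNode_t error_node;                  // CUDA < 12.0: separate error node ...
    cudaGraphExecUpdateResult result;            // ... and result enum
    return cudaGraphExecUpdate(instance, graph, &error_node, &result);
#endif
}
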
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
index 9206bfeba..2af63355a 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
@@ -26,6 +26,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
         : "=r"(ret) : "r"(x));
 #else
+    GGML_UNUSED(x);
     NO_DEVICE_CODE;
 #endif // defined(NEW_MMA_AVAILABLE)
     return ret;
@@ -178,6 +179,7 @@ namespace ggml_cuda_mma {
             : "l"(xs));
 #else
         load_generic(xs0, stride);
+        GGML_UNUSED(t);
 #endif // NEW_MMA_AVAILABLE
     }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
index 933d945c8..b36b43d54 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
@@ -27,8 +27,8 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
-        cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
+    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
+        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
     const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
 
     switch (src0->type) {
@@ -145,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
     return true;
 #endif //GGML_CUDA_FORCE_MMQ
 
-    if (cc < GGML_CUDA_CC_OFFSET_AMD) {
+    if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
index 66ce2bc95..532358018 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
@@ -90,7 +90,7 @@ struct tile_x_sizes {
 
 static int get_mmq_x_max_host(const int cc) {
     return new_mma_available(cc) ? 128 :
-        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
+        GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128                     : 64;
 #else
@@ -109,9 +109,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
 
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #ifdef GGML_CUDA_FORCE_MMQ
-    return MMQ_DP4A_MAX_BATCH_SIZE;
-#else // GGML_CUDA_FORCE_MMQ
     return 128;
+#else // GGML_CUDA_FORCE_MMQ
+    return MMQ_DP4A_MAX_BATCH_SIZE;
 #endif // GGML_CUDA_FORCE_MMQ
 #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
@@ -123,8 +123,8 @@ static constexpr __device__ int get_mmq_x_max_device() {
 }
 
 static int get_mmq_y_host(const int cc) {
-    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
-        (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
+    return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
+        ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 
 static constexpr __device__ int get_mmq_y_device() {
@@ -945,7 +945,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         }
     }
 #else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -1024,7 +1024,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     }
 
 #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1035,19 +1035,34 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
             for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                if (k01 < WARP_SIZE/2) {
-                    constexpr int ns = 2;
-                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
-                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
-                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
-                } else {
-                    constexpr int ns = 1;
-                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
-                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
-                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
-                }
+                constexpr int ns = 2;
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+            }
+        }
+    }
+
+    // Some compilers fail to unroll the loop over k01 if the inner loop contains a conditional statement for ns.
+    // As a workaround, two separate loops are used instead.
+#pragma unroll
+    for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                constexpr int ns = 1;
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
         }
     }
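
The two-loop split above works around compilers that will not fully unroll a "#pragma unroll" loop whose body selects the template argument ns behind a condition. A self-contained sketch of the pattern (dot<ns>() is a hypothetical stand-in for vec_dot_q2_K_q8_1_impl_mmq; the constants are illustrative):

// Splitting at the compile-time boundary makes ns a constant in each loop, so both unroll cleanly.
template <int ns>
static __device__ __forceinline__ float dot(const float * x, int k01) {
    return ns * x[k01];   // placeholder for the real ns-dependent dot product
}

template <int warp_size, int step>
static __device__ float accumulate(const float * x) {
    float acc = 0.0f;
#pragma unroll
    for (int k01 = 0; k01 < warp_size/2; k01 += step) {           // first half: ns == 2
        acc += dot<2>(x, k01);
    }
#pragma unroll
    for (int k01 = warp_size/2; k01 < warp_size; k01 += step) {   // second half: ns == 1
        acc += dot<1>(x, k01);
    }
    return acc;
}
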
@@ -1176,7 +1191,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         }
     }
 #else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -1253,7 +1268,7 @@ template  static __device__ __forceinlin
         const float d = bxi->d;
 
 #pragma unroll
-        for (int l = 0; l < sizeof(int); ++l) {
+        for (int l = 0; l < int(sizeof(int)); ++l) {
             x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
         }
 #else
@@ -1376,7 +1391,7 @@ template  static __device__ __forceinlin
         const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 
 #pragma unroll
-        for (int l = 0; l < sizeof(int); ++l) {
+        for (int l = 0; l < int(sizeof(int)); ++l) {
             x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
         }
     }
@@ -1517,7 +1532,7 @@ template  static __device__ __forceinlin
         const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 
 #pragma unroll
-        for (int l = 0; l < sizeof(int); ++l) {
+        for (int l = 0; l < int(sizeof(int)); ++l) {
             x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
         }
     }
@@ -1810,7 +1825,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
     }
 #else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -2570,6 +2585,8 @@ static __device__ void mul_mat_q_process_tile(
     } else {
         write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
     }
+
+    GGML_UNUSED(ne00); GGML_UNUSED(ne10);
 }
 
 
@@ -2695,7 +2712,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
         const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
 
         // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
-        if (it != blockIdx.x || jt != blockIdx.y) {
+        if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) {
             continue;
         }
 
@@ -2772,14 +2789,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     const int shmem = mmq_get_shmem(mmq_x, mmq_y, cc);
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
         CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         shmem_limit_raised[id] = true;
     }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
 
     const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
     const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
@@ -2825,14 +2842,13 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 template <ggml_type type>
 void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id    = ggml_cuda_get_device();
-    const int nsm   = ggml_cuda_info().devices[id].nsm;
     const int cc    = ggml_cuda_info().devices[id].cc;
     const int smpbo = ggml_cuda_info().devices[id].smpbo;
 
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
+    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
 
     int mmq_x_best  = 0;
     int nparts_best = INT_MAX;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu
index f89ed03b5..b39961cd1 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu
@@ -29,7 +29,7 @@ static __global__ void mul_mat_vec(
         __syncthreads();
     }
 
-    float sumf;
+    float sumf = 0.0f;
 
     if constexpr (std::is_same<T, half>::value) {
         const half2 * x2 = (const half2 *) x;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
index 23ae7abcb..eef8585a7 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
@@ -47,11 +47,89 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         1;
 }
 
+enum mmvq_parameter_table_id {
+    MMVQ_PARAMETERS_GENERIC = 0,
+    MMVQ_PARAMETERS_GCN,
+    MMVQ_PARAMETERS_RDNA2
+};
+
+static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
+#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
+    return MMVQ_PARAMETERS_RDNA2;
+#elif defined(GCN) || defined(CDNA)
+    return MMVQ_PARAMETERS_GCN;
+#else
+    return MMVQ_PARAMETERS_GENERIC;
+#endif
+}
+
+static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+        return MMVQ_PARAMETERS_RDNA2;
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMVQ_PARAMETERS_GCN;
+    }
+    return MMVQ_PARAMETERS_GENERIC;
+}
+
+static constexpr __host__ __device__ int calc_nwarps(int ncols_y,  mmvq_parameter_table_id table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 4;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else if (table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 2;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+                return 1;
+            case 2:
+            case 3:
+            case 4:
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -59,25 +137,21 @@ static __global__ void mul_mat_vec_q(
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
+    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
+    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
-    constexpr int nwarps              = 1;
-    constexpr int rows_per_cuda_block = 1;
-#else
-    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
-
-    const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const     int tid = warp_size*threadIdx.y + threadIdx.x;
     const     int row0 = rows_per_cuda_block*blockIdx.x;
     const     int blocks_per_row_x = ncols_x / qk;
     const     int blocks_per_col_y = nrows_y / QK8_1;
-    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-// partial sum for each thread
-    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
+    // partial sum for each thread
+    float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
 
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
         }
     }
 
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -120,13 +194,22 @@ static __global__ void mul_mat_vec_q(
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
             }
-            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }
 
-        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
             dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
+
+    GGML_UNUSED(nrows_x);
+}
+
+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
 }
 
 template <ggml_type type>
@@ -137,65 +220,67 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
-    int id = ggml_cuda_get_device();
-
-    int64_t nwarps = 1;
-    int64_t rows_per_cuda_block = 1;
-
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
-        switch(ncols_y) {
-            case 1:
-                nwarps = 4;
-                rows_per_cuda_block = 1;
-                break;
-            case 2:
-            case 3:
-            case 4:
-                nwarps = 4;
-                rows_per_cuda_block = 2;
-                break;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                nwarps = 2;
-                rows_per_cuda_block = 2;
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-    }
-
-    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         default:
             GGML_ABORT("fatal error");
             break;
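
A quick worked example for the new launch-parameter tables above, assuming a generic NVIDIA device (table_id == MMVQ_PARAMETERS_GENERIC, warp_size == 32) and ncols_y == 4: calc_nwarps returns 4 and calc_rows_per_block returns 2, so calc_launch_params yields a block of (32, 4, 1) threads and ceil(nrows_x / 2) blocks along x. A host-side sketch mirroring that arithmetic:

#include <cstdio>

int main() {
    const int nrows_x        = 4096;   // illustrative row count
    const int warp_size      = 32;     // assumed NVIDIA warp size
    const int nwarps         = 4;      // calc_nwarps(4, MMVQ_PARAMETERS_GENERIC)
    const int rows_per_block = 2;      // calc_rows_per_block(4, MMVQ_PARAMETERS_GENERIC)
    const int nblocks        = (nrows_x + rows_per_block - 1) / rows_per_block;
    std::printf("grid = (%d, 1, 1), block = (%d, %d, 1)\n", nblocks, warp_size, nwarps);
    return 0;
}
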
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
index f127616ed..0020dbcec 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
@@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
     }
 }
 
+// template <int block_size>
+// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+//     const int row = blockIdx.x*blockDim.y + threadIdx.y;
+//     const int tid = threadIdx.x;
+
+//     float tmp = 0.0f; // partial sum for thread in warp
+
+//     for (int col = tid; col < ncols; col += block_size) {
+//         const float xi = x[row*ncols + col];
+//         tmp += xi * xi;
+//     }
+
+//     // sum up partial sums
+//     tmp = warp_reduce_sum(tmp);
+//     if (block_size > WARP_SIZE) {
+//         __shared__ float s_sum[32];
+//         int warp_id = threadIdx.x / WARP_SIZE;
+//         int lane_id = threadIdx.x % WARP_SIZE;
+//         if (lane_id == 0) {
+//             s_sum[warp_id] = tmp;
+//         }
+//         __syncthreads();
+//         tmp = s_sum[lane_id];
+//         tmp = warp_reduce_sum(tmp);
+//     }
+
+//     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+//     const float scale = rsqrtf(fmaxf(tmp, eps * eps));
+
+//     for (int col = tid; col < ncols; col += block_size) {
+//         dst[row*ncols + col] = scale * x[row*ncols + col];
+//     }
+// }
+
+template <int block_size>
+static __global__ void l2_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+    const float scale = rsqrtf(fmaxf(tmp, eps * eps));
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale * x[col];
+    }
+}
+
 static void norm_f32_cuda(
         const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
         const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
@@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
     }
 }
 
+static void l2_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *) src0->data;
@@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
 
     rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
 }
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
+
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
+}
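
The new l2_norm path follows torch.nn.functional.normalize along the last dimension: y = x / max(||x||_2, eps), computed in the kernel as scale = rsqrt(max(sum(x_i^2), eps^2)). A minimal single-row CPU reference (sketch) that the CUDA result should agree with:

#include <algorithm>
#include <cmath>
#include <cstddef>

// Sketch: per-row L2 normalization matching scale = rsqrt(max(sum(x^2), eps^2)).
static void l2_norm_row_ref(const float * x, float * dst, size_t ncols, float eps) {
    float sum = 0.0f;
    for (size_t i = 0; i < ncols; ++i) {
        sum += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(std::max(sum, eps * eps));
    for (size_t i = 0; i < ncols; ++i) {
        dst[i] = scale * x[i];
    }
}
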
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh
index d63d34380..706a5660a 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh
@@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
index b4b874093..7d45a7e19 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
@@ -14,7 +14,7 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
         nidx +
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
-    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+    if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
         int offset_src =
             nidx +
             blockIdx.y * ne00 +
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu
new file mode 100644
index 000000000..f63757196
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu
@@ -0,0 +1,148 @@
+#include "ssm-conv.cuh"
+
+template <size_t split_d_inner, size_t d_conv>
+static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                    const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
+                                    float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
+                                    const int64_t n_t) {
+    GGML_UNUSED(src0_nb0);
+    const int tid  = threadIdx.x;
+    const int bidx = blockIdx.x;
+    const int bidy = blockIdx.y;
+
+    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
+    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
+    float *       y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
+
+    const int stride_x = src0_nb1 / sizeof(float);
+    const int stride_w = src1_nb1 / sizeof(float);
+    const int stride_y = dst_nb1 / sizeof(float);
+
+    float x[d_conv] = { 0.0f };
+    float w[d_conv] = { 0.0f };
+
+#pragma unroll
+    for (size_t j = 0; j < d_conv; j++) {
+        w[j] = w_block[tid * stride_w + j];
+    }
+
+    for (int64_t i = 0; i < n_t; i++) {
+        float sumf = 0.0f;
+
+        if (i == 0) {
+            for (size_t j = 0; j < d_conv; j++) {
+                x[j] = x_block[tid * stride_x + j];
+            }
+        } else {
+            x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
+        }
+
+#pragma unroll
+        for (size_t j = 0; j < d_conv; j++) {
+            sumf += x[(i + j) % d_conv] * w[j];
+        }
+        y_block[i * stride_y + tid] = sumf;
+    }
+}
+
+template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
+static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                               const int src0_nb0, const int src0_nb1, const int src0_nb2,
+                                               const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
+                                               const int dst_nb1, const int dst_nb2, const int64_t n_t) {
+    const int tid  = threadIdx.x;
+    const int bidx = blockIdx.x;
+    const int bidy = blockIdx.y;
+    const int bidz = blockIdx.z;
+
+    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
+                                             bidz * split_n_t * src0_nb0);
+    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
+    float *       y_block =
+        (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
+
+    const int stride_x = src0_nb1 / sizeof(float);
+    const int stride_w = src1_nb1 / sizeof(float);
+    const int stride_y = dst_nb1 / sizeof(float);
+
+    float x[d_conv] = { 0.0f };
+    float w[d_conv] = { 0.0f };
+
+#pragma unroll
+    for (size_t j = 0; j < d_conv; j++) {
+        w[j] = w_block[tid * stride_w + j];
+    }
+
+#pragma unroll
+    for (int64_t i = 0; i < split_n_t; i++) {
+        if (bidz * split_n_t + i < n_t) {
+            float sumf = 0.0f;
+
+            if (i == 0) {
+                for (size_t j = 0; j < d_conv; j++) {
+                    x[j] = x_block[tid * stride_x + j];
+                }
+            } else {
+                x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
+            }
+
+#pragma unroll
+            for (size_t j = 0; j < d_conv; j++) {
+                sumf += x[(i + j) % d_conv] * w[j];
+            }
+            y_block[i * stride_y + tid] = sumf;
+        }
+    }
+}
+
+static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
+                              const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
+                              const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
+                              const int64_t n_s, cudaStream_t stream) {
+    const int threads = 128;
+    GGML_ASSERT(nr % threads == 0);
+
+    if (n_t <= 32) {
+        const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
+        if (nc == 4) {
+            ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+        } else {
+            GGML_ABORT("Only kernel size = 4 is supported right now.");
+        }
+    } else {
+        if (nc == 4) {
+            const int64_t split_n_t = 32;
+            dim3          blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+            ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
+                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+        } else {
+            GGML_ABORT("Only kernel size = 4 is supported right now.");
+        }
+    }
+}
+
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];  // conv_x
+    const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight
+
+    const int64_t nc  = src1->ne[0];                // d_conv
+    const int64_t nr  = src0->ne[1];                // d_inner
+    const int64_t n_t = dst->ne[1];                 // tokens per sequence
+    const int64_t n_s = dst->ne[2];                 // number of sequences in the batch
+
+    GGML_ASSERT(dst->ne[0] == nr);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float *       dst_d  = (float *) dst->data;
+    cudaStream_t  stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
+                      dst->nb[2], nc, nr, n_t, n_s, stream);
+}
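
For reference, the rolling register array in the kernels above is just a sliding window: after the initial load, each new token overwrites slot (i - 1) % d_conv, and the output at token i is the dot product of the taps with the window starting at i in the padded per-channel input (d_conv - 1 carried-over samples followed by n_t new ones). A plain CPU sketch of the same computation for one channel:

// Sketch: `in` has length n_t + d_conv - 1 (conv state + tokens), `w` holds the d_conv taps.
static void ssm_conv_row_ref(const float * in, const float * w, float * out, int n_t, int d_conv) {
    for (int i = 0; i < n_t; ++i) {
        float sum = 0.0f;
        for (int j = 0; j < d_conv; ++j) {
            sum += in[i + j] * w[j];   // window [i, i + d_conv) against the taps
        }
        out[i] = sum;
    }
}
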
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cuh
new file mode 100644
index 000000000..8e6c1f00b
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu
new file mode 100644
index 000000000..37ee208c0
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu
@@ -0,0 +1,153 @@
+#include "ssm-scan.cuh"
+
+template <size_t splitD, size_t N>
+__global__ void __launch_bounds__(splitD, 2)
+    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
+                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+                 const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
+                 const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
+                 const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
+                 float * __restrict__ dst, const int64_t L) {
+    GGML_UNUSED(src1_nb0);
+    GGML_UNUSED(src2_nb0);
+    const int bidx = blockIdx.x;  // split along B
+    const int bidy = blockIdx.y;  // split along D
+    const int tid  = threadIdx.x;
+    const int wid  = tid / 32;
+    const int wtid = tid % 32;
+
+    extern __shared__ float smem[];
+    const int               stride_sA  = N + 1;
+    const int               stride_ss0 = N + 1;
+    float *                 smem_A     = smem;
+    float *                 smem_s0    = smem_A + splitD * stride_sA;
+
+    const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
+    const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
+    const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
+    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb2));
+    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb2));
+    float *       y_block  = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
+    float *       s_block  = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+
+    const int stride_s0 = src0_nb1 / sizeof(float);
+    const int stride_x  = src1_nb1 / sizeof(float);
+    const int stride_dt = src2_nb1 / sizeof(float);
+    const int stride_A  = src3_nb1 / sizeof(float);
+    const int stride_B  = src4_nb1 / sizeof(float);
+    const int stride_C  = src5_nb1 / sizeof(float);
+    const int stride_s  = stride_s0;
+    const int stride_y  = stride_x;
+
+    // can N be something other than 16, for example 32?
+    if (N == 16) {
+#pragma unroll
+        for (size_t i = 0; i < splitD / 4; i += 2) {
+            float value = A_block[(wid * warpSize + i) * stride_A + wtid];
+            // todo: bank conflict
+            // It is unclear how to apply a swizzling scheme here to avoid the
+            // bank conflict; suggestions are welcome.
+            smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
+        }
+#pragma unroll
+        for (size_t i = 0; i < splitD / 4; i += 2) {
+            float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
+            smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
+        }
+    }
+
+    __syncthreads();
+
+    for (int64_t i = 0; i < L; i++) {
+        float dt_soft_plus = dt_block[i * stride_dt + tid];
+        if (dt_soft_plus <= 20.0f) {
+            dt_soft_plus = log1pf(exp(dt_soft_plus));
+        }
+        float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
+        float sumf = 0.0f;
+#pragma unroll
+        for (size_t j = 0; j < N; j++) {
+            float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
+                          (B_block[i * stride_B + j] * x_dt);
+            sumf += state * C_block[i * stride_C + j];
+            if (i == L - 1) {
+                s_block[tid * stride_s + j] = state;
+            } else {
+                smem_s0[tid * stride_ss0 + j] = state;
+            }
+        }
+        __syncthreads();
+        y_block[i * stride_y + tid] = sumf;
+    }
+}
+
+static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
+                              const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
+                              const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
+                              const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
+                              const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
+                              float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
+                              cudaStream_t stream) {
+    const int threads = 128;
+    // todo: can D fail to be divisible by the thread count? handle that case if it occurs.
+    GGML_ASSERT(D % threads == 0);
+    const dim3 blocks(B, (D + threads - 1) / threads, 1);
+    const int  smem_size = (threads * (N + 1) * 2) * sizeof(float);
+    if (N == 16) {
+        ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
+            src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
+            src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+    } else {
+        GGML_ABORT("doesn't support N!=16.");
+    }
+}
+
+void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];  // s
+    const struct ggml_tensor * src1 = dst->src[1];  // x
+    const struct ggml_tensor * src2 = dst->src[2];  // dt
+    const struct ggml_tensor * src3 = dst->src[3];  // A
+    const struct ggml_tensor * src4 = dst->src[4];  // B
+    const struct ggml_tensor * src5 = dst->src[5];  // C
+
+    //   const int64_t d_state = src0->ne[0];
+    //   const int64_t d_inner = src0->ne[1];
+    //   const int64_t l = src1->ne[1];
+    //   const int64_t b = src0->ne[2];
+
+    const int64_t nc  = src0->ne[0];  // d_state
+    const int64_t nr  = src0->ne[1];  // d_inner
+    const int64_t n_t = src1->ne[1];  // number of tokens per sequence
+    const int64_t n_s = src0->ne[2];  // number of sequences in the batch
+
+    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(float));
+    GGML_ASSERT(src4->nb[0] == sizeof(float));
+    GGML_ASSERT(src5->nb[0] == sizeof(float));
+    // required for the dot product between s and C
+    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
+    // required for per-sequence offsets for states
+    GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    const float * src2_d = (const float *) src2->data;
+    const float * src3_d = (const float *) src3->data;
+    const float * src4_d = (const float *) src4->data;
+    const float * src5_d = (const float *) src5->data;
+    float *       dst_d  = (float *) dst->data;
+    cudaStream_t  stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
+                      src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
+                      src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
+}
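
Per (sequence, channel) the kernel above implements the selective-scan recurrence: dt' = softplus(dt) (skipped for dt > 20, where softplus(x) is numerically x), then for each of the N states s = s * exp(dt' * A) + B_t * x_t * dt', with y_t = sum over states of s * C_t, and the final states written back after the last token. A scalar CPU sketch for one channel:

#include <cmath>

// Sketch: N states, L tokens; B and C are laid out as [L][N]; s holds the running state.
static void ssm_scan_ref(float * s, const float * x, const float * dt, const float * A,
                         const float * B, const float * C, float * y, int N, int L) {
    for (int t = 0; t < L; ++t) {
        const float dtp  = dt[t] <= 20.0f ? std::log1p(std::exp(dt[t])) : dt[t];   // softplus with overflow guard
        const float x_dt = x[t] * dtp;
        float sum = 0.0f;
        for (int j = 0; j < N; ++j) {
            s[j] = s[j] * std::exp(dtp * A[j]) + B[t * N + j] * x_dt;   // state update
            sum += s[j] * C[t * N + j];                                 // readout
        }
        y[t] = sum;
    }
}
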
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cuh
new file mode 100644
index 000000000..ee078f5eb
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu
index 6b21f407d..ec5773e01 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu
@@ -1,305 +1,213 @@
 #include "unary.cuh"
 
-static __global__ void neg_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = -x[i];
+static __device__ __forceinline__ float op_abs(float x) {
+    return fabsf(x);
 }
 
-static __global__ void step_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = x[i] > 0.0f;
+static __device__ __forceinline__ float op_sgn(float x) {
+    return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f)));
 }
 
-static __global__ void gelu_f32(const float * x, float * dst, const int k) {
+static __device__ __forceinline__ float op_neg(float x) {
+    return -x;
+}
+
+static __device__ __forceinline__ float op_step(float x) {
+    return x > 0.0f;
+}
+
+static __device__ __forceinline__ float op_gelu(float x) {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= k) {
-        return;
-    }
-
-    float xi = x[i];
-    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
-static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
+static __device__ __forceinline__ float op_gelu_quick(float x) {
     const float GELU_QUICK_COEF = -1.702f;
-    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+
+    return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x)));
 }
 
-static __global__ void silu_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] / (1.0f + expf(-x[i]));
+static __device__ __forceinline__ float op_silu(float x) {
+    return x / (1.0f + expf(-x));
 }
 
-static __global__ void silu_back_f32(
-        const float * grad, const float * xf, float * dst, const int k) {
+static __device__ __forceinline__ float op_tanh(float x) {
+    return tanhf(x);
+}
+
+static __device__ __forceinline__ float op_relu(float x) {
+    return fmaxf(x, 0);
+}
+
+static __device__ __forceinline__ float op_sigmoid(float x) {
+    return 1.0f / (1.0f + expf(-x));
+}
+
+static __device__ __forceinline__ float op_hardsigmoid(float x) {
+    return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
+}
+
+static __device__ __forceinline__ float op_hardswish(float x) {
+    return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
+}
+
+static __device__ __forceinline__ float op_exp(float x) {
+    return expf(x);
+}
+
+static __device__ __forceinline__ float op_sqr(float x) {
+    return x * x;
+}
+
+static __device__ __forceinline__ float op_sqrt(float x) {
+    return sqrtf(x);
+}
+
+static __device__ __forceinline__ float op_sin(float x) {
+    return sinf(x);
+}
+
+static __device__ __forceinline__ float op_cos(float x) {
+    return cosf(x);
+}
+
+static __device__ __forceinline__ float op_log(float x) {
+    return logf(x);
+}
+
+template <float (*op)(float), typename T>
+static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
-    const float xfi = xf[i];
-    const float s = 1.0f / (1.0f + expf(-xfi));
-    dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
+    dst[i] = (T)op((float)x[i]);
 }
 
-static __global__ void tanh_f32(const float * x, float * dst, int k) {
-    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
-    if (i >= k) {
-        return;
-    }
-    dst[i] = tanhf(x[i]);
-}
-
-static __global__ void relu_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = fmaxf(x[i], 0);
-}
-
-static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = 1.0f / (1.0f + expf(-x[i]));
-}
-
-static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
-}
-
-static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
-}
-
-static __global__ void exp_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = expf(x[i]);
-}
-
-static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
-    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
-    if (i >= k) {
-        return;
-    }
-    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
-}
-
-static __global__ void sqr_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * x[i];
-}
-
-static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sqrtf(x[i]);
-}
-
-static __global__ void sin_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sinf(x[i]);
-}
-
-static __global__ void cos_f32(const float * x, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = cosf(x[i]);
-}
-
-static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <float (*op)(float), typename T>
+static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
-    neg_f32<<>>(x, dst, k);
+    unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
-    step_f32<<>>(x, dst, k);
+template <float (*op)(float)>
+void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+
+    if (src0->type == GGML_TYPE_F16) {
+        unary_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        unary_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
-static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
-    gelu_f32<<>>(x, dst, k);
+void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_abs>(ctx, dst);
 }
 
-static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
-    gelu_quick_f32<<>>(x, dst, k);
-}
-
-static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
-    silu_f32<<>>(x, dst, k);
-}
-
-static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
-    silu_back_f32<<>>(grad, x, dst, k);
-}
-
-static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
-    tanh_f32<<>>(x, dst, k);
-}
-
-static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
-    relu_f32<<>>(x, dst, k);
-}
-
-static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
-    sigmoid_f32<<>>(x, dst, k);
-}
-
-static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
-    hardsigmoid_f32<<>>(x, dst, k);
-}
-
-static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
-    hardswish_f32<<>>(x, dst, k);
-}
-
-static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
-    exp_f32<<>>(x, dst, k);
-}
-
-static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
-    leaky_relu_f32<<>>(x, dst, k, negative_slope);
-}
-
-static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
-    sqr_f32<<>>(x, dst, k);
-}
-
-static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
-    sqrt_f32<<>>(x, dst, k);
-}
-
-static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
-    sin_f32<<>>(x, dst, k);
-}
-
-static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
-    cos_f32<<>>(x, dst, k);
+void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_sgn>(ctx, dst);
 }
 
 void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    ggml_cuda_op_unary<op_neg>(ctx, dst);
 }
 
 void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    ggml_cuda_op_unary<op_step>(ctx, dst);
 }
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
+    ggml_cuda_op_unary<op_gelu>(ctx, dst);
+}
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
 }
 
 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
+    ggml_cuda_op_unary<op_silu>(ctx, dst);
+}
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
+void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_tanh>(ctx, dst);
+}
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_relu>(ctx, dst);
+}
 
-    silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
+}
+
+void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_hardsigmoid>(ctx, dst);
+}
+
+void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_hardswish>(ctx, dst);
+}
+
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_exp>(ctx, dst);
+}
+
+void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_sqr>(ctx, dst);
+}
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_sqrt>(ctx, dst);
+}
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_sin>(ctx, dst);
+}
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_cos>(ctx, dst);
+}
+
+void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_log>(ctx, dst);
+}
+
+/* silu_back */
+
+static __device__ __forceinline__ float op_silu_back(float grad, float x) {
+    const float s = 1.0f / (1.0f + expf(-x));
+    return grad * s * (1.0f + x * (1.0f - s));
+}
+
+template <class T>
+static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]);
+}
+
+template <class T>
+static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BACK_BLOCK_SIZE;
+    silu_back_kernel<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
 }
 
 void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -314,179 +222,58 @@ void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        silu_back_cuda((const float*)src0_d, (const float*)src1_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
-void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
+/* leaky relu */
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) {
+    return fmaxf(x, 0) + fminf(x, 0.0f) * negative_slope;
 }
 
-void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
+template <class T>
+static __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) {
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
+    if (i >= k) {
+        return;
+    }
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    dst[i] = (T)op_leaky_relu((float)x[i], negative_slope);
 }
 
-void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
-}
-
-void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
-}
-
-void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
-}
-
-void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
-}
-
-void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+template <class T>
+static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    leaky_relu_kernel<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
 }
 
 void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
     float negative_slope;
     memcpy(&negative_slope, dst->op_params, sizeof(float));
 
-    leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
-}
-
-void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
-}
-
-void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
-}
-
-void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
-}
-
-void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        leaky_relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), negative_slope, stream);
+    } else {
+        leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
+    }
 }
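
Editorial note (not part of the patch): the unary.cu hunks above replace the hand-written per-activation kernels with thin wrappers around a shared, op-templated ggml_cuda_op_unary helper (defined earlier in the file, outside this excerpt). The self-contained CUDA sketch below shows the same kernel-template pattern; UNARY_BLOCK_SIZE, op_silu, unary_op_kernel and unary_cuda are illustrative names for this note, not the ggml code.

#include <cuda_runtime.h>
#include <cstdio>

#define UNARY_BLOCK_SIZE 256

// example device-side op; every activation handled above has this shape
static __device__ __forceinline__ float op_silu(float x) {
    return x / (1.0f + expf(-x));
}

// one templated kernel instead of one hand-written kernel per op
template <float (*Op)(float), class T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = (T)Op((float)x[i]);
}

template <float (*Op)(float), class T>
static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + UNARY_BLOCK_SIZE - 1) / UNARY_BLOCK_SIZE;
    unary_op_kernel<Op><<<num_blocks, UNARY_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

int main() {
    const int k = 8;
    float h_x[8], h_y[8];
    for (int i = 0; i < k; i++) {
        h_x[i] = (float)i - 4.0f;
    }

    float *d_x, *d_y;
    cudaMalloc(&d_x, k*sizeof(float));
    cudaMalloc(&d_y, k*sizeof(float));
    cudaMemcpy(d_x, h_x, k*sizeof(float), cudaMemcpyHostToDevice);

    unary_cuda<op_silu>(d_x, d_y, k, 0);   // default stream

    cudaMemcpy(h_y, d_y, k*sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < k; i++) {
        printf("silu(%.1f) = %.4f\n", h_x[i], h_y[i]);
    }

    cudaFree(d_x);
    cudaFree(d_y);
    return 0;
}

Because the op and the element type are both compile-time parameters, one launcher serves every (activation, F32/F16) pair, which is what lets the wrappers above collapse to one-line forwards.
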
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh
index e7f62643a..940a1feed 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh
@@ -16,6 +16,10 @@
 #define CUDA_SIN_BLOCK_SIZE 256
 #define CUDA_COS_BLOCK_SIZE 256
 
+void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -49,3 +53,5 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/upscale.cu b/ml/backend/ggml/ggml/src/ggml-cuda/upscale.cu
index cf513c3ad..524e97957 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/upscale.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/upscale.cu
@@ -19,7 +19,7 @@ static __global__ void upscale_f32(const float * x, float * dst,
     int i02 = i12 / sf2;
     int i03 = i13 / sf3;
 
-    dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
+    dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
 }
 
 static void upscale_f32_cuda(const float * x, float * dst,
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
index a62544b5a..420b41b8d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
@@ -20,6 +20,7 @@
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_TF32_TENSOR_OP_MATH 0
 #define CUDA_R_16F  HIPBLAS_R_16F
+#define CUDA_R_16BF HIPBLAS_R_16B
 #define CUDA_R_32F  HIPBLAS_R_32F
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
 #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
@@ -112,7 +113,7 @@
 #define cudaGraphExecDestroy hipGraphExecDestroy
 #define cudaGraphLaunch hipGraphLaunch
 #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
-#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
 #define cudaGraphNodeType hipGraphNodeType
 #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
 #define cudaGraphInstantiate hipGraphInstantiate
@@ -129,6 +130,7 @@
 #define cudaGraph_t hipGraph_t
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
@@ -150,7 +152,7 @@
 #define CDNA
 #endif
 
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__GFX12__)
 #define RDNA4
 #endif
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h
index 6cc1b69ee..937779a90 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h
@@ -15,6 +15,7 @@
 #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
 #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
 #define CUDA_R_16F  MUSA_R_16F
+#define CUDA_R_16BF MUSA_R_16BF
 #define CUDA_R_32F  MUSA_R_32F
 #define cublasComputeType_t cudaDataType_t
 #define cublasCreate mublasCreate
@@ -119,7 +120,7 @@
 #define cudaGraphExecDestroy musaGraphExecDestroy
 #define cudaGraphExec_t musaGraphExec_t
 #define cudaGraphExecUpdate musaGraphExecUpdate
-#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
 #define cudaGraphGetNodes musaGraphGetNodes
 #define cudaGraphInstantiate musaGraphInstantiate
 #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +133,8 @@
 #define cudaGraph_t musaGraph_t
 #define cudaKernelNodeParams musaKernelNodeParams
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture musaStreamBeginCapture
 #define cudaStreamEndCapture musaStreamEndCapture
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
 
 typedef mt_bfloat16 nv_bfloat16;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/wkv.cu b/ml/backend/ggml/ggml/src/ggml-cuda/wkv.cu
new file mode 100644
index 000000000..d2fced705
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/wkv.cu
@@ -0,0 +1,199 @@
+#include "common.cuh"
+#include "wkv.cuh"
+
+template <int block_size>
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    __syncthreads();
+    _tf[tid] = tf[head_i * head_size + tid];
+    __syncthreads();
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& k = (float4&)(_k[j]);
+            const float4& r = (float4&)(_r[j]);
+            const float4& tf = (float4&)(_tf[j]);
+            const float4& td = (float4&)(_td[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            y += r.x * (tf.x * kv.x + s.x);
+            y += r.y * (tf.y * kv.y + s.y);
+            y += r.z * (tf.z * kv.z + s.z);
+            y += r.w * (tf.w * kv.w + s.w);
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+template <int block_size>
+static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
+
+#ifndef GGML_USE_MUSA
+    #pragma unroll
+#endif
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        __syncthreads();
+
+        float sa = 0;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4)
+        {
+            const float4& a = (float4&)(_a[j]);
+            const float4& s = (float4&)(state[j]);
+            sa += a.x * s.x;
+            sa += a.y * s.y;
+            sa += a.z * s.z;
+            sa += a.w * s.w;
+        }
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& r = (float4&)(_r[j]);
+            const float4& w = (float4&)(_w[j]);
+            const float4& k = (float4&)(_k[j]);
+            const float4& b = (float4&)(_b[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * w.x + kv.x + sa * b.x;
+            s.y = s.y * w.y + kv.y + sa * b.y;
+            s.z = s.z * w.z + kv.z + sa * b.z;
+            s.w = s.w * w.w + kv.w + sa * b.w;
+
+            y += s.x * r.x;
+            y += s.y * r.y;
+            y += s.z * r.z;
+            y += s.w * r.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * tf_d = (const float *)dst->src[3]->data;
+    const float * td_d = (const float *)dst->src[4]->data;
+    const float * s_d  = (const float *)dst->src[5]->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+    if (C / H == CUDA_WKV_BLOCK_SIZE) {
+        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+    } else {
+        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * r_d = (const float *)dst->src[0]->data;
+    const float * w_d = (const float *)dst->src[1]->data;
+    const float * k_d = (const float *)dst->src[2]->data;
+    const float * v_d = (const float *)dst->src[3]->data;
+    const float * a_d = (const float *)dst->src[4]->data;
+    const float * b_d = (const float *)dst->src[5]->data;
+    const float * s_d = (const float *)dst->src[6]->data;
+
+    const int64_t B = dst->src[6]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+    if (C / H == CUDA_WKV_BLOCK_SIZE) {
+        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+    } else {
+        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+    }
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/wkv.cuh
similarity index 62%
rename from ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cuh
rename to ml/backend/ggml/ggml/src/ggml-cuda/wkv.cuh
index a7124ee51..9623dd7f8 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/wkv.cuh
@@ -3,3 +3,5 @@
 #define CUDA_WKV_BLOCK_SIZE 64
 
 void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu b/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu
deleted file mode 100644
index bbdafbee5..000000000
--- a/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu
+++ /dev/null
@@ -1,89 +0,0 @@
-#include "common.cuh"
-#include "wkv6.cuh"
-
-static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
-    const int tid = threadIdx.x;
-    const int bid = blockIdx.x;
-
-    const int head_size = CUDA_WKV_BLOCK_SIZE;
-    const int batch_i = bid / H;
-    const int head_i = bid % H;
-    const int state_size = C * head_size;
-    const int n_seq_tokens = T / B;
-
-    float state[head_size];
-    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
-
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
-    }
-
-    __syncthreads();
-    _tf[tid] = tf[head_i * head_size + tid];
-    __syncthreads();
-
-    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
-        __syncthreads();
-        _k[tid] = k[t];
-        _r[tid] = r[t];
-        _td[tid] = td[t];
-        __syncthreads();
-
-        const float _v = v[t];
-        float y = 0;
-        for (int j = 0; j < head_size; j += 4) {
-            const float4& k = (float4&)(_k[j]);
-            const float4& r = (float4&)(_r[j]);
-            const float4& tf = (float4&)(_tf[j]);
-            const float4& td = (float4&)(_td[j]);
-            float4& s = (float4&)(state[j]);
-            float4 kv;
-
-            kv.x = k.x * _v;
-            kv.y = k.y * _v;
-            kv.z = k.z * _v;
-            kv.w = k.w * _v;
-
-            y += r.x * (tf.x * kv.x + s.x);
-            y += r.y * (tf.y * kv.y + s.y);
-            y += r.z * (tf.z * kv.z + s.z);
-            y += r.w * (tf.w * kv.w + s.w);
-
-            s.x = s.x * td.x + kv.x;
-            s.y = s.y * td.y + kv.y;
-            s.z = s.z * td.z + kv.z;
-            s.w = s.w * td.w + kv.w;
-        }
-        dst[t] = y;
-    }
-
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
-    }
-}
-
-void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const float * k_d  = (const float *)dst->src[0]->data;
-    const float * v_d  = (const float *)dst->src[1]->data;
-    const float * r_d  = (const float *)dst->src[2]->data;
-    const float * tf_d = (const float *)dst->src[3]->data;
-    const float * td_d = (const float *)dst->src[4]->data;
-    const float * s_d  = (const float *)dst->src[5]->data;
-
-    const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[2];
-    const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[1];
-
-    float * dst_d = (float *)dst->data;
-
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
-    GGML_ASSERT(C % H == 0);
-    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
-
-    rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
-}
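
Editorial note (not part of the patch): the rwkv_wkv_f32 kernels above (both the new templated wkv.cu version and the removed wkv6.cu one) assign one output channel per thread and keep that channel's head_size state values in registers, mixing them with the current k*v product at every token. The scalar sketch below (a hypothetical wkv6_step helper, plain C++) spells out that recurrence for a single channel so the float4 arithmetic in the kernel is easier to follow.

#include <cstdio>
#include <vector>

// One (token, head, output-channel) step of the WKV6 recurrence that the
// kernel vectorizes with float4. `state` holds the head_size state entries
// owned by this output channel; k, r, tf, td are the per-head vectors the
// kernel stages in shared memory; v is this channel's value for the token.
static float wkv6_step(std::vector<float> & state,
                       const std::vector<float> & k,
                       const std::vector<float> & r,
                       const std::vector<float> & tf,
                       const std::vector<float> & td,
                       const float v) {
    float y = 0.0f;
    for (size_t j = 0; j < state.size(); ++j) {
        const float kv = k[j] * v;            // kv = k * _v
        y += r[j] * (tf[j] * kv + state[j]);  // bonus term plus carried state
        state[j] = state[j] * td[j] + kv;     // s = s * td + kv
    }
    return y;
}

int main() {
    const size_t head_size = 4;
    std::vector<float> state(head_size, 0.0f);
    std::vector<float> k (head_size, 0.5f);
    std::vector<float> r (head_size, 1.0f);
    std::vector<float> tf(head_size, 0.1f);
    std::vector<float> td(head_size, 0.9f);

    // two token steps with identical inputs; the second sees the carried state
    printf("y0 = %f\n", wkv6_step(state, k, r, tf, td, 2.0f));
    printf("y1 = %f\n", wkv6_step(state, k, r, tf, td, 2.0f));
    return 0;
}
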
diff --git a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
index 4a0384dd4..e3762649f 100644
--- a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
@@ -39,6 +39,12 @@ endif()
 find_package(hip     REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
+if (GGML_HIP_ROCWMMA_FATTN)
+    CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
+    if (NOT ${FOUND_ROCWMMA})
+        message(FATAL_ERROR "rocwmma has not been found")
+    endif()
+endif()
 
 if (${hip_VERSION} VERSION_LESS 5.5)
     message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
@@ -107,6 +113,10 @@ if (GGML_HIP_NO_VMM)
     add_compile_definitions(GGML_HIP_NO_VMM)
 endif()
 
+if (GGML_HIP_ROCWMMA_FATTN)
+    add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
+endif()
+
 if (NOT GGML_CUDA_FA)
     add_compile_definitions(GGML_CUDA_NO_FA)
 endif()
diff --git a/ml/backend/ggml/ggml/src/ggml-impl.h b/ml/backend/ggml/ggml/src/ggml-impl.h
index 1fbcbd045..a19cfb14e 100644
--- a/ml/backend/ggml/ggml/src/ggml-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-impl.h
@@ -148,8 +148,14 @@ struct ggml_map_custom2_op_params {
 
 struct ggml_map_custom3_op_params {
     ggml_custom3_op_t fun;
-    int n_tasks;
-    void * userdata;
+    int               n_tasks;
+    void            * userdata;
+};
+
+struct ggml_custom_op_params {
+    ggml_custom_op_t fun;
+    int              n_tasks;
+    void           * userdata;
 };
 
 // bitset
@@ -311,29 +317,28 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
 
 // FP16 to FP32 conversion
 
-#if defined(__ARM_NEON)
-    #if defined(_MSC_VER) || (defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11)
-        typedef uint16_t ggml_fp16_internal_t;
-    #else
-        typedef __fp16 ggml_fp16_internal_t;
-    #endif
-#endif
-
-#if defined(__ARM_NEON) && !defined(_MSC_VER) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11)
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+//
+// for old CUDA compilers (<= 11), we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/10616
+// for     MUSA compilers        , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843
+//
+#if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
     #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
     #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
 
     #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
 
     static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-        ggml_fp16_internal_t tmp;
+        __fp16 tmp;
         memcpy(&tmp, &h, sizeof(ggml_fp16_t));
         return (float)tmp;
     }
 
     static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
         ggml_fp16_t res;
-        ggml_fp16_internal_t tmp = f;
+        __fp16 tmp = f;
         memcpy(&res, &tmp, sizeof(ggml_fp16_t));
         return res;
     }
@@ -357,8 +362,8 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
     #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
 
     static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-        register float f;
-        register double d;
+        float f;
+        double d;
         __asm__(
             "mtfprd %0,%2\n"
             "xscvhpdp %0,%0\n"
@@ -370,8 +375,8 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
     }
 
     static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-        register double d;
-        register ggml_fp16_t r;
+        double d;
+        ggml_fp16_t r;
         __asm__( /* xscvdphp can work on double or single precision */
             "xscvdphp %0,%2\n"
             "mffprd %1,%0\n" :
@@ -381,6 +386,35 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
         return r;
     }
 
+#elif defined(__riscv) && defined(GGML_RV_ZFH)
+
+    static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+        float f;
+        __asm__(
+            "fmv.h.x %[f], %[h]\n\t"
+            "fcvt.s.h %[f], %[f]"
+            : [f] "=&f" (f)
+            : [h] "r" (h)
+        );
+        return f;
+    }
+
+    static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+        ggml_fp16_t res;
+        __asm__(
+            "fcvt.h.s %[f], %[f]\n\t"
+            "fmv.x.h %[h], %[f]"
+            : [h] "=&r" (res)
+            : [f] "f" (f)
+        );
+        return res;
+    }
+
+    #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+    #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+    #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
+    #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
 #else
 
     // FP16 <-> FP32
@@ -456,7 +490,7 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
     #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
     #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
 
-#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
+#endif // defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
 
 // precomputed f32 table for f16 (256 KB)
 // defined in ggml.c, initialized in ggml_init()
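
Editorial note (not part of the patch): the reworked ggml-impl.h comments above spell out which FP16 <-> FP32 path is chosen per platform (__fp16 on Arm, inline asm on targets like the RISC-V Zfh path added above, otherwise uint16_t storage with software conversion plus the precomputed table mentioned at the end of the hunk). As a rough illustration of what such a software fallback has to do, here is a minimal FP16-bits-to-FP32 converter; fp16_bits_to_fp32 is a name made up for this note, not the ggml implementation.

#include <cstdint>
#include <cstring>
#include <cstdio>

// Decode an IEEE-754 binary16 value stored in a uint16_t into a float,
// handling zeros, subnormals, infinities and NaNs explicitly.
static float fp16_bits_to_fp32(uint16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000) << 16;
    const uint32_t exp  = (h >> 10) & 0x1f;
    const uint32_t mant = h & 0x3ff;

    uint32_t bits;
    if (exp == 0) {
        if (mant == 0) {
            bits = sign;                                  // +/- zero
        } else {
            // subnormal: renormalize the mantissa and adjust the exponent
            int e = -1;
            uint32_t m = mant;
            do { m <<= 1; e++; } while ((m & 0x400) == 0);
            bits = sign | ((uint32_t)(127 - 15 - e) << 23) | ((m & 0x3ff) << 13);
        }
    } else if (exp == 0x1f) {
        bits = sign | 0x7f800000 | (mant << 13);          // inf / NaN
    } else {
        bits = sign | ((exp - 15 + 127) << 23) | (mant << 13);
    }

    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

int main() {
    printf("%f\n", fp16_bits_to_fp32(0x3c00));  //  1.0
    printf("%f\n", fp16_bits_to_fp32(0xc000));  // -2.0
    printf("%f\n", fp16_bits_to_fp32(0x0001));  // smallest subnormal
    return 0;
}

On hot paths ggml avoids doing this arithmetic per element, which is exactly why the 256 KB table of all 65536 converted values referenced above is precomputed at ggml_init() time.
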
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt
index 89fcde2fa..e22232780 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt
@@ -27,12 +27,12 @@ configure_file(../ggml-common.h  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
 configure_file(ggml-metal.metal  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal  COPYONLY)
 configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
 
+set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
 if (GGML_METAL_EMBED_LIBRARY)
     enable_language(ASM)
 
     add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
 
-    set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
     set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
     set(METALLIB_IMPL   "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
 
@@ -88,12 +88,11 @@ else()
 
     add_custom_command(
         OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
-        COMMAND xcrun -sdk macosx metal    ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
-        COMMAND xcrun -sdk macosx metallib                ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air   -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
-        COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
+        COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
+            xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
         COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
         COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
-        DEPENDS ggml-metal.metal ggml-common.h
+        DEPENDS ggml-metal.metal ${METALLIB_COMMON}
         COMMENT "Compiling Metal kernels"
         )
 
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
index a2f599ce5..eaf6a23d8 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
@@ -162,6 +162,12 @@ typedef sycl::half2 ggml_half2;
 
 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
 
+#ifdef _MSC_VER
+#define GGML_EXTENSION
+#else // _MSC_VER
+#define GGML_EXTENSION __extension__
+#endif // _MSC_VER
+
 #define QK4_0 32
 typedef struct {
     ggml_half d;           // delta
@@ -171,7 +177,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 b
 
 #define QK4_1 32
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d; // delta
             ggml_half m; // min
@@ -192,7 +198,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0
 
 #define QK5_1 32
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d; // delta
             ggml_half m; // min
@@ -213,7 +219,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block
 
 #define QK8_1 32
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d; // delta
             ggml_half s; // d * sum(qs[i])
@@ -254,7 +260,7 @@ static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
     uint8_t qs[QK_K/4];      // quants
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
@@ -281,7 +287,7 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12
 // weight is represented as x = a * q + b
 // Effectively 4.5 bits per weight
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
@@ -298,7 +304,7 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2,
 // weight is represented as x = a * q + b
 // Effectively 5.5 bits per weight
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
@@ -1854,12 +1860,75 @@ GGML_TABLE_END()
 #endif // GGML_COMMON_IMPL
 #endif // GGML_COMMON_IMPL
 #else
-// TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift
-#include "../ggml-common.h"
+#include "ggml-common.h"
 #endif
 #ifndef GGML_METAL_IMPL
 #define GGML_METAL_IMPL
 
+// kernel parameters for mat-vec threadgroups
+//
+// N_R0: number of src0 rows to process per simdgroup
+// N_SG: number of simdgroups per threadgroup
+//
+// TODO: for optimal performance, become function of the device and work size
+
+#define N_R0_Q4_0 4
+#define N_SG_Q4_0 2
+
+#define N_R0_Q4_1 4
+#define N_SG_Q4_1 2
+
+#define N_R0_Q5_0 4
+#define N_SG_Q5_0 2
+
+#define N_R0_Q5_1 4
+#define N_SG_Q5_1 2
+
+#define N_R0_Q8_0 4
+#define N_SG_Q8_0 2
+
+#define N_R0_Q2_K 4
+#define N_SG_Q2_K 2
+
+#define N_R0_Q3_K 2
+#define N_SG_Q3_K 2
+
+#define N_R0_Q4_K 4
+#define N_SG_Q4_K 2
+
+#define N_R0_Q5_K 2
+#define N_SG_Q5_K 2
+
+#define N_R0_Q6_K 1
+#define N_SG_Q6_K 2
+
+#define N_R0_IQ1_S 4
+#define N_SG_IQ1_S 2
+
+#define N_R0_IQ1_M 4
+#define N_SG_IQ1_M 2
+
+#define N_R0_IQ2_XXS 4
+#define N_SG_IQ2_XXS 2
+
+#define N_R0_IQ2_XS 4
+#define N_SG_IQ2_XS 2
+
+#define N_R0_IQ2_S 4
+#define N_SG_IQ2_S 2
+
+#define N_R0_IQ3_XXS 4
+#define N_SG_IQ3_XXS 2
+
+#define N_R0_IQ3_S 4
+#define N_SG_IQ3_S 2
+
+#define N_R0_IQ4_NL 2
+#define N_SG_IQ4_NL 2
+
+#define N_R0_IQ4_XS 2
+#define N_SG_IQ4_XS 2
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -2014,9 +2083,12 @@ typedef struct {
     int32_t  ne11;
     int32_t  ne_12_2; // assume K and V are same shape
     int32_t  ne_12_3;
-    uint64_t nb_12_1;
-    uint64_t nb_12_2;
-    uint64_t nb_12_3;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
     uint64_t nb31;
     int32_t  ne1;
     int32_t  ne2;
@@ -2144,6 +2216,248 @@ typedef struct {
     float    eps;
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_l2_norm;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  n_groups;
+    float    eps;
+} ggml_metal_kargs_group_norm;
+
+typedef struct {
+    int32_t  IC;
+    int32_t  IL;
+    int32_t  K;
+    int32_t  s0;
+    uint64_t nb0;
+    uint64_t nb1;
+} ggml_metal_kargs_conv_transpose_1d;
+
+typedef struct {
+    uint64_t  ofs0;
+    uint64_t  ofs1;
+    int32_t  IW;
+    int32_t  IH;
+    int32_t  CHW;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  p0;
+    int32_t  p1;
+    int32_t  d0;
+    int32_t  d1;
+    int32_t  N;
+    int32_t  KH;
+    int32_t  KW;
+    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
+} ggml_metal_kargs_im2col;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne10;
+    int64_t  ne11;
+    int64_t  ne12;
+    int64_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_sum_rows;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    uint32_t n_head_log2;
+} ggml_metal_kargs_soft_max;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int      n_past;
+} ggml_metal_kargs_diag_mask_inf;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t  ne10;
+    int64_t  ne11;
+    uint64_t nb10;
+    uint64_t nb11;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_ssm_conv;
+
+typedef struct {
+    int64_t  d_state;
+    int64_t  d_inner;
+    int64_t  n_seq_tokens;
+    int64_t  n_seqs;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb20;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb30;
+    uint64_t nb31;
+    uint64_t nb40;
+    uint64_t nb41;
+    uint64_t nb42;
+    uint64_t nb50;
+    uint64_t nb51;
+    uint64_t nb52;
+} ggml_metal_kargs_ssm_scan;
+
+typedef struct {
+    int64_t  ne00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t  ne10;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_get_rows;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    float    sf0;
+    float    sf1;
+    float    sf2;
+    float    sf3;
+} ggml_metal_kargs_upscale;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_pad;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  p0;
+    int32_t  p1;
+} ggml_metal_kargs_pad_reflect_1d;
+
+typedef struct {
+    uint64_t nb1;
+    int      dim;
+    int      max_period;
+} ggml_metal_kargs_timestep_embedding;
+
+typedef struct {
+    float    slope;
+} ggml_metal_kargs_leaky_relu;
+
+typedef struct {
+    int64_t  ncols;
+    int64_t  ncols_pad;
+} ggml_metal_kargs_argsort;
+
+typedef struct {
+    int64_t  ne0;
+    float    start;
+    float    step;
+} ggml_metal_kargs_arange;
+
+typedef struct {
+    int32_t  k0;
+    int32_t  k1;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  p0;
+    int32_t  p1;
+    int64_t  IH;
+    int64_t  IW;
+    int64_t  OH;
+    int64_t  OW;
+    int64_t  parallel_elements;
+} ggml_metal_kargs_pool_2d;
+
 #endif // GGML_METAL_IMPL
 
 #include <metal_stdlib>
@@ -2187,7 +2501,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4>
 void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
-    reg = (type4)(*(src + il));
+    reg = (type4)(*(src));
 }
 
 #if defined(GGML_METAL_USE_BF16)
@@ -2195,6 +2509,11 @@ template 
 void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
     reg = (type4x4)(*src);
 }
+
+template <typename type4>
+void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
+    reg = (type4)(*(src));
+}
 #endif
 
 template 
@@ -3093,45 +3412,22 @@ kernel void kernel_neg(
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
+        constant ggml_metal_kargs_sum_rows & args,
         uint3 tpig[[thread_position_in_grid]]) {
     int64_t i3 = tpig.z;
     int64_t i2 = tpig.y;
     int64_t i1 = tpig.x;
 
-    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }
 
-    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
     float row_sum = 0;
 
-    for (int64_t i0 = 0; i0 < ne00; i0++) {
+    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
         row_sum += src_row[i0];
     }
 
@@ -3143,36 +3439,29 @@ kernel void kernel_soft_max(
         device const  char * src0,
         device const  char * src1,
         device        char * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint32_t & n_head_log2,
+        constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (ne02*ne01);
-    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
-    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+    const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
+    const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
+    const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
 
-    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
-    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
+    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*args.ne00 : nullptr;
+    device       float * pdst  = (device       float *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
 
     float slope = 1.0f;
 
     // ALiBi
-    if (max_bias > 0.0f) {
+    if (args.max_bias > 0.0f) {
         const int64_t h = i02;
 
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
         slope = pow(base, exp);
     }
@@ -3180,8 +3469,8 @@ kernel void kernel_soft_max(
     // parallel max
     float lmax = -INFINITY;
 
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
 
     // find the max value in the block
@@ -3205,8 +3494,8 @@ kernel void kernel_soft_max(
 
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+        const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
         pdst[i00] = exp_psrc0;
     }
@@ -3236,7 +3525,7 @@ kernel void kernel_soft_max(
 
     const float inv_sum = 1.0f/sum;
 
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
         pdst[i00] *= inv_sum;
     }
 }
@@ -3246,35 +3535,28 @@ kernel void kernel_soft_max_4(
         device const  char * src0,
         device const  char * src1,
         device        char * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint32_t & n_head_log2,
+        constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (ne02*ne01);
-    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
-    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+    const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
+    const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
+    const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
 
-    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
-    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
+    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*args.ne00/4 : nullptr;
+    device       float4 * pdst4 = (device       float4 *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
 
     float slope = 1.0f;
 
-    if (max_bias > 0.0f) {
+    if (args.max_bias > 0.0f) {
         const int64_t h = i02;
 
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
         slope = pow(base, exp);
     }
@@ -3282,8 +3564,8 @@ kernel void kernel_soft_max_4(
     // parallel max
     float4 lmax4 = -INFINITY;
 
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -3308,8 +3590,8 @@ kernel void kernel_soft_max_4(
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -3341,7 +3623,7 @@ kernel void kernel_soft_max_4(
 
     const float inv_sum = 1.0f/sum;
 
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
         pdst4[i00] *= inv_sum;
     }
 }
@@ -3357,27 +3639,23 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant       int & n_past,
+        constant ggml_metal_kargs_diag_mask_inf & args,
         uint3 tpig[[thread_position_in_grid]]) {
     const int64_t i02 = tpig[2];
     const int64_t i01 = tpig[1];
     const int64_t i00 = tpig[0];
 
-    if (i00 > n_past + i01) {
-        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    if (i00 > args.n_past + i01) {
+        dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY;
     } else {
-        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+        dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00];
     }
 }
 
 kernel void kernel_diag_mask_inf_8(
         device const float4 * src0,
         device       float4 * dst,
-        constant    int64_t & ne00,
-        constant    int64_t & ne01,
-        constant        int & n_past,
+        constant ggml_metal_kargs_diag_mask_inf & args,
         uint3 tpig[[thread_position_in_grid]]) {
 
     const int64_t i = 2*tpig[0];
@@ -3385,42 +3663,26 @@ kernel void kernel_diag_mask_inf_8(
     dst[i+0] = src0[i+0];
     dst[i+1] = src0[i+1];
     int64_t i4 = 4*i;
-    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
-    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01;
+    const int64_t i01 = i4/(args.ne00);      i4 -= i01*args.ne00;
     const int64_t i00 = i4;
     for (int k = 3; k >= 0; --k) {
-        if (i00 + 4 + k <= n_past + i01) {
+        if (i00 + 4 + k <= args.n_past + i01) {
             break;
         }
         dst[i+1][k] = -INFINITY;
-        if (i00 + k > n_past + i01) {
+        if (i00 + k > args.n_past + i01) {
             dst[i][k] = -INFINITY;
         }
     }
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
-// TODO: optimize
 kernel void kernel_ssm_conv_f32(
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_ssm_conv & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -3428,15 +3690,15 @@ kernel void kernel_ssm_conv_f32(
     const int64_t i2 = tgpig.y;
     const int64_t i3 = tgpig.z;
 
-    const int64_t nc  = ne10;
-  //const int64_t ncs = ne00;
-  //const int64_t nr  = ne01;
-  //const int64_t n_t = ne1;
-  //const int64_t n_s = ne2;
+    const int64_t nc  = args.ne10;
+  //const int64_t ncs = args.ne00;
+  //const int64_t nr  = args.ne01;
+  //const int64_t n_t = args.ne1;
+  //const int64_t n_s = args.ne2;
 
-    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
-    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
-    device       float * x = (device       float *) ((device       char *) dst  + ir*nb0  + i2*nb1  + i3*nb2);
+    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
+    device       float * x = (device       float *) ((device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2);
 
     float sumf = 0.0f;
 
@@ -3448,7 +3710,6 @@ kernel void kernel_ssm_conv_f32(
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32
-// TODO: optimize
 kernel void kernel_ssm_scan_f32(
         device const void * src0,
         device const void * src1,
@@ -3457,48 +3718,27 @@ kernel void kernel_ssm_scan_f32(
         device const void * src4,
         device const void * src5,
         device      float * dst,
-        constant  int64_t & d_state,
-        constant  int64_t & d_inner,
-        constant  int64_t & n_seq_tokens,
-        constant  int64_t & n_seqs,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant uint64_t & nb20,
-        constant uint64_t & nb21,
-        constant uint64_t & nb22,
-        constant uint64_t & nb30,
-        constant uint64_t & nb31,
-        constant uint64_t & nb40,
-        constant uint64_t & nb41,
-        constant uint64_t & nb42,
-        constant uint64_t & nb50,
-        constant uint64_t & nb51,
-        constant uint64_t & nb52,
+        constant ggml_metal_kargs_ssm_scan & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
     const int64_t ir = tgpig.x;
     const int64_t i3 = tgpig.y;
 
-    const int64_t nc  = d_state;
-  //const int64_t nr  = d_inner;
-    const int64_t n_t = n_seq_tokens;
-  //const int64_t n_s = n_seqs;
+    const int64_t nc  = args.d_state;
+    // const int64_t nr  = args.d_inner;
+    const int64_t n_t = args.n_seq_tokens;
+    // const int64_t n_s = args.n_seqs;
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
-        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
-        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
-        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
-        device       float * y  = (device       float *) ((device       char *) dst  + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
-        device       float * s  = (device       float *) ((device       char *) dst  + ir*nb01 + i3*nb02 +    nb13);
+        device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
+        device const float * x  = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31);
+        device const float * B  = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
+        device const float * C  = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
+        device       float * y  = (device       float *) ((device       char *) dst  + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
+        device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb01 + i3*args.nb02 +    args.nb13);
 
         if (i2 > 0) {
             s0 = s;
@@ -3520,6 +3760,184 @@ kernel void kernel_ssm_scan_f32(
     }
 }
 
+kernel void kernel_rwkv_wkv6_f32(
+    device const float * k,
+    device const float * v,
+    device const float * r,
+    device const float * tf,
+    device const float * td,
+    device const float * state_in,
+    device       float * dst,
+    constant    uint & B,
+    constant    uint & T,
+    constant    uint & C,
+    constant    uint & H,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]])  {
+
+    const uint head_size = 64; // TODO: support head_size = 128
+    const uint batch_id = tgpig.x / H;
+    const uint head_id = tgpig.x % H;
+    const uint tid = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _k[head_size];
+    threadgroup float _r[head_size];
+    threadgroup float _tf[head_size];
+    threadgroup float _td[head_size];
+
+    float state[head_size];
+
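+    // each thread loads one column (index tid) of the head_size x head_size per-head state into registers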
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                          + i * head_size + tid];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    _tf[tid] = tf[head_id * head_size + tid];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0;
+
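+        // WKV6 recurrence, processed in float4 chunks:
+        //   y     += r . (tf*k*v + state)
+        //   state  = state*td + k*v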
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            float4 temp = tf_vec * kv + s_vec;
+            y += dot(r_vec, temp);
+
+            s_vec = s_vec * td_vec + kv;
+            state[j]   = s_vec[0];
+            state[j+1] = s_vec[1];
+            state[j+2] = s_vec[2];
+            state[j+3] = s_vec[3];
+        }
+
+        dst[t] = y;
+    }
+
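+    // write the updated state back, appended after the T*C output elements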
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + i * head_size + tid] = state[i];
+    }
+}
+
+kernel void kernel_rwkv_wkv7_f32(
+    device const float * r,
+    device const float * w,
+    device const float * k,
+    device const float * v,
+    device const float * a,
+    device const float * b,
+    device const float * state_in,
+    device       float * dst,
+    constant    uint & B,
+    constant    uint & T,
+    constant    uint & C,
+    constant    uint & H,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]])  {
+
+    const uint head_size = 64; // TODO: support head_size = 128
+    const uint batch_id = tgpig.x / H;
+    const uint head_id = tgpig.x % H;
+    const uint tid = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _r[head_size];
+    threadgroup float _w[head_size];
+    threadgroup float _k[head_size];
+    threadgroup float _a[head_size];
+    threadgroup float _b[head_size];
+
+    float state[head_size];
+
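+    // note: unlike wkv6, the per-thread state here is indexed as tid*head_size + i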
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                          + tid * head_size + i];
+    }
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0, sa = 0.0;
+
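+        // first pass: sa = a . state (reduction over the head dimension)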
+        float4 sa_vec(0.0);
+
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+            sa_vec += a_vec * s_vec;
+        }
+        sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
+
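+        // second pass: state = state*w + k*v + sa*b, then y += state . r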
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            s_vec = s_vec * w_vec + kv + sa * b_vec;
+            y += dot(s_vec, r_vec);
+
+            state[j]   = s_vec[0];
+            state[j+1] = s_vec[1];
+            state[j+2] = s_vec[2];
+            state[j+3] = s_vec[3];
+        }
+
+        dst[t] = y;
+    }
+
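+    // append the final state after the T*C outputs, matching the tid*head_size + i layout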
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + tid * head_size + i] = state[i];
+    }
+}
+
 kernel void kernel_argmax(
         device   const void * x,
         device      int32_t * dst,
@@ -3688,25 +4106,61 @@ kernel void kernel_rms_norm(
     }
 }
 
+kernel void kernel_l2_norm(
+        constant ggml_metal_kargs_l2_norm & args,
+        device const char * src0,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+    float sumf = 0.0f;
+
+    // parallel sum
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        sumf += dot(x[i00], x[i00]);
+    }
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
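+    // scale the row to unit L2 norm; eps guards against division by zero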
+    const float scale = 1.0f/sqrt(max(sumf, args.eps));
+
+    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
+
 kernel void kernel_group_norm(
         device const float * src0,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int32_t & n_groups,
-        constant     float & eps,
+        constant ggml_metal_kargs_group_norm & args,
         threadgroup float  * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    const int64_t ne = ne00*ne01*ne02;
-    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+    const int64_t ne = args.ne00*args.ne01*args.ne02;
+    const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups);
 
     int start = tgpig * gs;
     int end   = start + gs;
@@ -3770,7 +4224,7 @@ kernel void kernel_group_norm(
     }
 
     const float variance = tmp / gs;
-    const float scale = 1.0f/sqrt(variance + eps);
+    const float scale = 1.0f/sqrt(variance + args.eps);
     for (int j = start; j < end; j += ntg) {
         dst[j] *= scale;
     }
@@ -3864,14 +4318,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
-// putting them in the kernel cause a significant performance penalty
-#define N_DST 4        // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
-//Note: This is a template, but strictly speaking it only applies to
-//      quantizations where the block size is 32. It also does not
-//      guard against the number of rows not being divisible by
-//      N_DST, so this is another explicit assumption of the implementation.
-template
+template
 void mul_vec_q_n_f32_impl(
         args_t args,
         device const char * src0,
@@ -3887,7 +4334,7 @@ void mul_vec_q_n_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -3899,15 +4346,15 @@ void mul_vec_q_n_f32_impl(
     device const float        * y = (device const float        *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q_type * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q_type * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
     }
 
     float yl[16]; // src1 vector cache
-    float sumf[nr] = {0.f};
+    float sumf[nr0] = {0.f};
 
     const short ix = (tiisg/2);
     const short il = (tiisg%2)*8;
@@ -3919,7 +4366,7 @@ void mul_vec_q_n_f32_impl(
         float sumy[2] = { 0.f, 0.f };
 
 #pragma unroll
-        for (int i = 0; i < 8; i += 2) {
+        for (short i = 0; i < 8; i += 2) {
             sumy[0]  += yb[i +  0] + yb[i +  1];
             yl[i + 0] = yb[i +  0];
             yl[i + 1] = yb[i +  1]/256.f;
@@ -3930,7 +4377,7 @@ void mul_vec_q_n_f32_impl(
         }
 
 #pragma unroll
-        for (int row = 0; row < nr; row++) {
+        for (short row = 0; row < nr0; row++) {
             sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
         }
 
@@ -3939,7 +4386,7 @@ void mul_vec_q_n_f32_impl(
 
     device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
 
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
         if (tiisg == 0 && first_row + row < args.ne01) {
@@ -3956,7 +4403,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -3967,7 +4414,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+     mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -3978,7 +4425,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -3989,12 +4436,12 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 #define NB_Q8_0 8
 
-template
+template
 void kernel_mul_mv_q8_0_f32_impl(
         args_t args,
         device const char * src0,
@@ -4004,16 +4451,13 @@ void kernel_mul_mv_q8_0_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
-    const int nr  = N_DST;
-    const int nsg = N_SIMDGROUP;
-    const int nw  = N_SIMDWIDTH;
-
     const int nb = args.ne00/QK8_0;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0*nsg + sgitg)*nr;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -4025,15 +4469,15 @@ void kernel_mul_mv_q8_0_f32_impl(
     device const float      * y = (device const float      *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q8_0 * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q8_0 * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
     }
 
     float yl[NB_Q8_0];
-    float sumf[nr] = { 0.f };
+    float sumf[nr0] = { 0.f };
 
     const short ix = tiisg/4;
     const short il = tiisg%4;
@@ -4046,7 +4490,7 @@ void kernel_mul_mv_q8_0_f32_impl(
             yl[i] = yb[i];
         }
 
-        for (int row = 0; row < nr; row++) {
+        for (short row = 0; row < nr0; row++) {
             device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
             float sumq = 0.f;
             for (short iq = 0; iq < NB_Q8_0; ++iq) {
@@ -4060,7 +4504,7 @@ void kernel_mul_mv_q8_0_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
         if (tiisg == 0 && first_row + row < args.ne01) {
@@ -4078,7 +4522,7 @@ kernel void kernel_mul_mv_q8_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // mat-vec kernel processing in chunks of float4
@@ -4415,9 +4859,9 @@ void kernel_mul_mv_impl(
                 sumf += (T0) x[i] * (T1) y[i];
             }
 
-            float all_sum = simd_sum(sumf);
+            float sum_all = simd_sum(sumf);
             if (tiisg == 0) {
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
             }
         }
     } else {
@@ -4438,10 +4882,10 @@ void kernel_mul_mv_impl(
                 sumf += dot((float4) x4[i], (float4) y4[i]);
             }
 
-            float all_sum = simd_sum(sumf);
+            float sum_all = simd_sum(sumf);
             if (tiisg == 0) {
-                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
             }
         }
     }
@@ -4503,9 +4947,9 @@ kernel void kernel_mul_mv_1row(
         for (int i = tiisg; i < args.ne00; i += 32) {
             sumf += (float) x[i] * (float) y[i];
         }
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
         if (tiisg == 0) {
-            dst_f32[r0] = all_sum;
+            dst_f32[r0] = sum_all;
         }
     } else {
         device const T4     * x4 = (device const T4     *) x;
@@ -4515,11 +4959,11 @@ kernel void kernel_mul_mv_1row(
             sumf += dot((float4) x4[i], y4[i]);
         }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
 
         if (tiisg == 0) {
-            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
-            dst_f32[r0] = all_sum;
+            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+            dst_f32[r0] = sum_all;
         }
     }
 }
@@ -4564,9 +5008,9 @@ kernel void kernel_mul_mv_l4(
             sumf += dot((float4) x4[i], y4[i]);
         }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
         if (tiisg == 0) {
-            dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+            dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
         }
     }
 }
@@ -4734,17 +5178,7 @@ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_
 typedef void (im2col_t)(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -4754,17 +5188,7 @@ template <typename T>
 kernel void kernel_im2col(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -4785,17 +5209,17 @@ kernel void kernel_im2col(
     const int64_t ioh = tgpig[1];
     const int64_t iow = tgpig[2];
 
-    const int64_t iiw = iow*s0 + ikw*d0 - p0;
-    const int64_t iih = ioh*s1 + ikh*d1 - p1;
+    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
+    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
 
-    const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
+    const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
 
     device T * pdst = (device T *) (dst);
 
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
         pdst[offset_dst] = 0.0f;
     } else {
-        const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
+        const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
         pdst[offset_dst] = x[offset_src];
     }
 }
@@ -4806,20 +5230,7 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col;
 typedef void (im2col_ext_t)(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
-        constant   int32_t & N,
-        constant   int32_t & KH,
-        constant   int32_t & KW,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -4829,53 +5240,40 @@ template <typename T>
 kernel void kernel_im2col_ext(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
-        constant   int32_t & N,
-        constant   int32_t & KH,
-        constant   int32_t & KW,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
-    const int64_t KHW = KH * KW;             // KHW == ntg[1] * ntg[2], KW == ntg[2]
+    const int64_t KHW = (int64_t)args.KHW;
 
-    const int64_t d = tgpig[0] / CHW;
-    const int64_t chw = tgpig[0] % CHW;
+    const int64_t d = tgpig[0] / args.CHW;
+    const int64_t chw = tgpig[0] % args.CHW;
     const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
     const int64_t HW = tgpig[0] % KHW;
 
     const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
-    if (tpitg_0 >= N) {
+    if (tpitg_0 >= args.N) {
         return;
     }
 
-    const int64_t tpitg_1 = HW / KW;
-    const int64_t tpitg_2 = HW % KW;
+    const int64_t tpitg_1 = HW / args.KW;
+    const int64_t tpitg_2 = HW % args.KW;
 
-    const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
-    const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
+    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
 
     const int64_t offset_dst =
-        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
-        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
+        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
+        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
 
     device T * pdst = (device T *) (dst);
 
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
         pdst[offset_dst] = 0.0f;
     } else {
-        const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
-        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
+        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
     }
 }
 
@@ -4886,12 +5284,7 @@ typedef void (conv_transpose_1d_t)(
         device const float * src0,
         device const float * src1,
         device        char * dst,
-        constant   int32_t & IC,
-        constant   int32_t & IL,
-        constant   int32_t & K,
-        constant   int32_t & s0,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
+        constant ggml_metal_kargs_conv_transpose_1d & args,
         uint3   tgpig[[threadgroup_position_in_grid]],
         uint3    tgpg[[threadgroups_per_grid]]);
 
@@ -4900,29 +5293,24 @@ kernel void kernel_conv_transpose_1d(
         device const     T * src0,
         device const float * src1,
         device        char * dst,
-        constant   int32_t & IC,
-        constant   int32_t & IL,
-        constant   int32_t & K,
-        constant   int32_t & s0,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
+        constant ggml_metal_kargs_conv_transpose_1d & args,
         uint3   tgpig[[threadgroup_position_in_grid]],
         uint3   tgpg[[threadgroups_per_grid]]) {
 
     float v = 0.0f;
 
-    for (int64_t c = 0; c < IC; c++) {
-        const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
-        const int32_t input_offset = c * IL;
+    for (int64_t c = 0; c < args.IC; c++) {
+        const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
+        const int32_t input_offset = c * args.IL;
 
-        for (int64_t i = 0; i < IL; i++) {
-            if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
-                v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+        for (int64_t i = 0; i < args.IL; i++) {
+            if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
+                v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
             }
         }
     }
 
-    device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+    device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
 
     dst_ptr[0] = v;
 }
@@ -4932,12 +5320,7 @@ kernel void kernel_conv_transpose_1d(
     device const float * src0,
     device const float * src1,
     device        char * dst,
-    constant   int32_t & IC,
-    constant   int32_t & IL,
-    constant   int32_t & K,
-    constant   int32_t & s0,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
+    constant ggml_metal_kargs_conv_transpose_1d & args,
     uint3   tgpig[[threadgroup_position_in_grid]],
     uint3    tgpg[[threadgroups_per_grid]]);
 
@@ -4946,38 +5329,14 @@ kernel void kernel_conv_transpose_1d(
     device const half  * src0,
     device const float * src1,
     device        char * dst,
-    constant   int32_t & IC,
-    constant   int32_t & IL,
-    constant   int32_t & K,
-    constant   int32_t & s0,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
+    constant ggml_metal_kargs_conv_transpose_1d & args,
     uint3   tgpig[[threadgroup_position_in_grid]],
     uint3    tgpg[[threadgroups_per_grid]]);
 
 kernel void kernel_upscale_f32(
     device  const char * src0,
     device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    constant     float & sf0,
-    constant     float & sf1,
-    constant     float & sf2,
-    constant     float & sf3,
+    constant ggml_metal_kargs_upscale & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
@@ -4986,15 +5345,15 @@ kernel void kernel_upscale_f32(
     const int64_t i2 = tgpig.y;
     const int64_t i1 = tgpig.x;
 
-    const int64_t i03 = i3/sf3;
-    const int64_t i02 = i2/sf2;
-    const int64_t i01 = i1/sf1;
+    const int64_t i03 = i3/args.sf3;
+    const int64_t i02 = i2/args.sf2;
+    const int64_t i01 = i1/args.sf1;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int64_t i00 = i0/sf0;
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int64_t i00 = i0/args.sf0;
 
-        device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0);
+        device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+        device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1  +  i0*args.nb0);
 
         dst_ptr[0] = src0_ptr[0];
     }
@@ -5003,22 +5362,7 @@ kernel void kernel_upscale_f32(
 kernel void kernel_pad_f32(
     device  const char * src0,
     device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
+    constant ggml_metal_kargs_pad & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
@@ -5031,12 +5375,12 @@ kernel void kernel_pad_f32(
     const int64_t i02 = i2;
     const int64_t i01 = i1;
 
-    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
-    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1);
 
-    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
-        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-            if (i0 < ne00) {
+    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            if (i0 < args.ne00) {
                 dst_ptr[i0] = src0_ptr[i0];
             } else {
                 dst_ptr[i0] = 0.0f;
@@ -5046,7 +5390,7 @@ kernel void kernel_pad_f32(
         return;
     }
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         dst_ptr[i0] = 0.0f;
     }
 }
@@ -5054,21 +5398,7 @@ kernel void kernel_pad_f32(
 kernel void kernel_pad_reflect_1d_f32(
     device  const char * src0,
     device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant   int64_t & ne0,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    constant   int32_t & p0,
-    constant   int32_t & p1,
+    constant   ggml_metal_kargs_pad_reflect_1d & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3  tgpg[[threadgroups_per_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
@@ -5082,17 +5412,17 @@ kernel void kernel_pad_reflect_1d_f32(
     const int64_t i02 = i2;
     const int64_t i01 = i1;
 
-    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
-    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1);
 
-    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
-        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-            if (i0 < p0) {
-                dst_ptr[i0] = src0_ptr[p0 - i0];
-            } else if (i0 < ne0 - p1) {
-                dst_ptr[i0] = src0_ptr[i0 - p0];
+    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            if (i0 < args.p0) {
+                dst_ptr[i0] = src0_ptr[args.p0 - i0];
+            } else if (i0 < args.ne0 - args.p1) {
+                dst_ptr[i0] = src0_ptr[i0 - args.p0];
             } else {
-                dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
+                dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
             }
         }
     }
@@ -5145,44 +5475,40 @@ kernel void kernel_unpad_f32(
 
 kernel void kernel_arange_f32(
     device        char * dst,
-    constant   int64_t & ne0,
-    constant   float   & start,
-    constant   float   & step,
+    constant   ggml_metal_kargs_arange & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
 
     device float * dst_ptr = (device float *) dst;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        dst_ptr[i0] = start + step * i0;
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        dst_ptr[i0] = args.start + args.step * i0;
     }
 }
 
 kernel void kernel_timestep_embedding_f32(
     device  const char * src0,
     device        char * dst,
-    constant  uint64_t & nb1,
-    constant  int      & dim,
-    constant  int      & max_period,
+    constant  ggml_metal_kargs_timestep_embedding & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
 
     int i = tgpig.x;
-    device float * embed_data = (device float *)(dst +  i*nb1);
+    device float * embed_data = (device float *)(dst + i*args.nb1);
 
-    int half_ = dim / 2;
+    int half_ = args.dim / 2;
     for (int j = tpitg.x; j < half_; j += ntg.x) {
         float timestep = ((device float *)src0)[i];
-        float freq = (float)exp(-log((float)max_period) * j / half_);
+        float freq = (float)exp(-log((float)args.max_period) * j / half_);
         float arg = timestep * freq;
         embed_data[j        ] = cos(arg);
         embed_data[j + half_] = sin(arg);
     }
 
-    if (dim % 2 != 0 && tpitg.x == 0) {
-        embed_data[dim] = 0.f;
+    if (args.dim % 2 != 0 && tpitg.x == 0) {
+        embed_data[args.dim] = 0.f;
     }
 }
 
@@ -5190,8 +5516,7 @@ kernel void kernel_timestep_embedding_f32(
 typedef void (argsort_t)(
         device const float  * x,
         device     int32_t  * dst,
-        constant   int64_t  & ncols,
-        constant   int64_t  & ncols_pad,
+        constant   ggml_metal_kargs_argsort & args,
         threadgroup int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]);
@@ -5200,8 +5525,7 @@ template<ggml_sort_order order>
 kernel void kernel_argsort_f32_i32(
         device const float   * x,
         device       int32_t * dst,
-        constant     int64_t & ncols,
-        constant     int64_t & ncols_pad,
+        constant   ggml_metal_kargs_argsort & args,
         threadgroup int32_t  * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]) {
@@ -5209,9 +5533,9 @@ kernel void kernel_argsort_f32_i32(
     int col = tpitg[0];
     int row = tgpig[1];
 
-    if (col >= ncols_pad) return;
+    if (col >= args.ncols_pad) return;
 
-    device const float   * x_row   = x + row * ncols;
+    device const float   * x_row   = x + row * args.ncols;
     threadgroup int32_t  * dst_row = shared_values;
 
     // initialize indices
@@ -5219,21 +5543,21 @@ kernel void kernel_argsort_f32_i32(
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (int k = 2; k <= ncols_pad; k *= 2) {
+    for (int k = 2; k <= args.ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (dst_row[col] >= ncols ||
-                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                    if (dst_row[col] >= args.ncols ||
+                        (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
                             x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                             x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                     ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (dst_row[ixj] >= ncols ||
-                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                    if (dst_row[ixj] >= args.ncols ||
+                        (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
                             x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                             x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                     ) {
@@ -5246,8 +5570,8 @@ kernel void kernel_argsort_f32_i32(
     }
 
     // copy the result to dst without the padding
-    if (col < ncols) {
-        dst[row * ncols + col] = dst_row[col];
+    if (col < args.ncols) {
+        dst[row * args.ncols + col] = dst_row[col];
     }
 }
 
@@ -5257,9 +5581,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar
 kernel void kernel_leaky_relu_f32(
         device const float * src0,
         device       float * dst,
-        constant     float & slope,
+        constant     ggml_metal_kargs_leaky_relu & args,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope;
 }
 
 // ref: https://arxiv.org/pdf/2307.08691.pdf
@@ -5286,7 +5610,8 @@ template<
     typename vd4x4_t, // value type in device memory
     short nl_v,
     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
-    short D,         // head size
+    short DK,        // K head size
+    short DV,        // V head size
     short Q  = 8,    // queries per threadgroup
     short KV = 8,    // key/value processed per each simdgroup
     short C  = 32>   // cache items per threadgroup
@@ -5308,20 +5633,24 @@ kernel void kernel_flash_attn_ext(
     const int iq2 = tgpig[1];
     const int iq1 = tgpig[0]*Q;
 
-    const short D4  = D/4;
-    const short D8  = D/8;
-    const short D16 = D/16;
-    const short NW  = N_SIMDWIDTH;
-    const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
+    constexpr short DK4  = DK/4;
+    constexpr short DK8  = DK/8;
+    constexpr short DK16 = DK/16;
+    constexpr short DV4  = DV/4;
+    constexpr short DV8  = DV/8;
+    constexpr short DV16 = DV/16;
+
+    constexpr short NW  = N_SIMDWIDTH;
+    constexpr short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
     const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
-    const short T  = D + 2*TS; // shared memory size per query in (half)
+    const short T  = DK + 2*TS; // shared memory size per query in (half)
 
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*D); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*D); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*D); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*D); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*DK); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*DK); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*DK); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -5330,23 +5659,23 @@ kernel void kernel_flash_attn_ext(
     threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    o8x8_t lo[D8];
+    o8x8_t lo[DV8];
 
     // load heads from Q to shared memory
     for (short j = sgitg; j < Q; j += nsg) {
         device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
 
-        for (short i = tiisg; i < D4; i += NW) {
+        for (short i = tiisg; i < DK4; i += NW) {
             if (iq1 + j < args.ne01) {
-                sq4[j*D4 + i] = (q4_t) q4[i];
+                sq4[j*DK4 + i] = (q4_t) q4[i];
             } else {
-                sq4[j*D4 + i] = (q4_t) 0.0f;
+                sq4[j*DK4 + i] = (q4_t) 0.0f;
             }
         }
     }
 
     // zero out lo
-    for (short i = 0; i < D8; ++i) {
+    for (short i = 0; i < DV8; ++i) {
         lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f);
     }
 
@@ -5360,8 +5689,8 @@ kernel void kernel_flash_attn_ext(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        half S[Q] = { [0 ... Q-1] = 0.0f };
-        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float S[Q] = { [0 ... Q-1] = 0.0f };
+        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -5376,22 +5705,15 @@ kernel void kernel_flash_attn_ext(
         const short ikv2 = iq2/(args.ne02/args.ne_12_2);
         const short ikv3 = iq3/(args.ne03/args.ne_12_3);
 
-        // load the queries from shared memory into local memory
-        q8x8_t mq[D8];
-
-        for (short i = 0; i < D8; ++i) {
-            simdgroup_load(mq[i], sq + i*8, D);
-        }
-
         const bool has_mask = mask != q;
 
-        half slope = 1.0f;
+        float slope = 1.0f;
 
         // ALiBi
         if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
             const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
@@ -5407,14 +5729,14 @@ kernel void kernel_flash_attn_ext(
 
             if (has_mask) {
                 // used to detect blocks full of -INF
-                half smax = -INFINITY;
+                float smax = -INFINITY;
 
                 // load the mask in shared memory
                 #pragma unroll(Q)
                 for (short j = 0; j < Q; ++j) {
                     device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
 
-                    const half m = pm[ic + tiisg];
+                    const float m = pm[ic + tiisg];
 
                     ss[j*TS + C + tiisg] = m;
                     smax = max(smax, m);
@@ -5435,20 +5757,22 @@ kernel void kernel_flash_attn_ext(
                     // this is compile-time check, so it does not have runtime overhead
                     if (is_same<kd4x4_t, k4x4_t>::value) {
                         // we can read directly from global memory
-                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
 
-                        #pragma unroll(D8)
-                        for (short i = 0; i < D8; ++i) {
+                        #pragma unroll(DK8)
+                        for (short i = 0; i < DK8; ++i) {
                             k8x8_t mk;
-                            simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+                            simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
 
-                            simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                            q8x8_t mq;
+                            simdgroup_load(mq, sq + i*8, DK);
+                            simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                         }
                     } else {
-                        for (short ii = 0; ii < D16; ii += 4) {
-                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        for (short ii = 0; ii < DK16; ii += 4) {
+                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
 
-                            if (D16%4 == 0) {
+                            if (DK16%4 == 0) {
                                 // the head is evenly divisible by 4*16 = 64, so no need for bound checks
                                 {
                                     k4x4_t tmp;
@@ -5461,15 +5785,18 @@ kernel void kernel_flash_attn_ext(
                                 #pragma unroll(4)
                                 for (short k = 0; k < 4; ++k) {
                                     k8x8_t mk;
+                                    q8x8_t mq;
 
                                     simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
 
                                     simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                                 }
                             } else {
-                                if (ii + tx < D16) {
+                                if (ii + tx < DK16) {
                                     k4x4_t tmp;
                                     deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
                                     sk4x4[4*ty + tx] = tmp;
@@ -5477,14 +5804,17 @@ kernel void kernel_flash_attn_ext(
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
-                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                for (short k = 0; k < 4 && ii + k < DK16; ++k) {
                                     k8x8_t mk;
+                                    q8x8_t mq;
 
                                     simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
 
                                     simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                                 }
                             }
                         }
@@ -5502,10 +5832,10 @@ kernel void kernel_flash_attn_ext(
             // online softmax
             {
                 for (ushort j = 0; j < Q; ++j) {
-                    const half m = M[j];
+                    const float m = M[j];
 
                     // scale and apply the logitcap / mask
-                    half s = ss[j*TS + tiisg]*args.scale;
+                    float s = ss[j*TS + tiisg]*args.scale;
 
                     if (args.logit_softcap != 0.0f) {
                         s = args.logit_softcap*precise::tanh(s);
@@ -5516,8 +5846,8 @@ kernel void kernel_flash_attn_ext(
 
                     M[j] = simd_max(max(M[j], s));
 
-                    const half ms = exp(m - M[j]);
-                    const half vs = exp(s - M[j]);
+                    const float ms = exp(m - M[j]);
+                    const float vs = exp(s - M[j]);
 
                     S[j] = S[j]*ms + simd_sum(vs);
 
@@ -5536,8 +5866,8 @@ kernel void kernel_flash_attn_ext(
                 s8x8_t mm;
                 simdgroup_load(mm, ss + 2*C, TS, 0, false);
 
-                #pragma unroll(D8)
-                for (short i = 0; i < D8; ++i) {
+                #pragma unroll(DV8)
+                for (short i = 0; i < DV8; ++i) {
                     simdgroup_multiply(lo[i], mm, lo[i]);
                 }
             }
@@ -5550,20 +5880,20 @@ kernel void kernel_flash_attn_ext(
 
                     if (is_same<vd4x4_t, v4x4_t>::value) {
                         // we can read directly from global memory
-                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
 
-                        #pragma unroll(D8)
-                        for (short i = 0; i < D8; ++i) {
+                        #pragma unroll(DV8)
+                        for (short i = 0; i < DV8; ++i) {
                             v8x8_t mv;
-                            simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
+                            simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
 
                             simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
                         }
                     } else {
-                        for (short ii = 0; ii < D16; ii += 4) {
-                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        for (short ii = 0; ii < DV16; ii += 4) {
+                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
 
-                            if (D16%4 == 0) {
+                            if (DV16%4 == 0) {
                                 // no need for bound checks
                                 {
                                     v4x4_t tmp;
@@ -5584,7 +5914,7 @@ kernel void kernel_flash_attn_ext(
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
                                 }
                             } else {
-                                if (ii + tx < D16) {
+                                if (ii + tx < DV16) {
                                     v4x4_t tmp;
                                     deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
                                     sv4x4[4*ty + tx] = tmp;
@@ -5592,7 +5922,7 @@ kernel void kernel_flash_attn_ext(
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
-                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                for (short k = 0; k < 4 && ii + k < DV16; ++k) {
                                     v8x8_t mv;
 
                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
@@ -5619,15 +5949,15 @@ kernel void kernel_flash_attn_ext(
 
     // reduce the warps sequentially
     for (ushort sg = 1; sg < nsg; ++sg) {
-        half S = { 0.0f };
-        half M = { -__FLT16_MAX__/2 };
+        float S = { 0.0f };
+        float M = { -__FLT16_MAX__/2 };
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // each simdgroup stores its output to shared memory, reusing sq
         if (sgitg == sg) {
-            for (short i = 0; i < D8; ++i) {
-                simdgroup_store(lo[i], so + i*8, D, 0, false);
+            for (short i = 0; i < DV8; ++i) {
+                simdgroup_store(lo[i], so + i*8, DV, 0, false);
             }
         }
 
@@ -5636,16 +5966,16 @@ kernel void kernel_flash_attn_ext(
         // the first simdgroup accumulates the results from the other simdgroups
         if (sgitg == 0) {
             for (short j = 0; j < Q; ++j) {
-                const half S0 = ss[j*TS +         0];
-                const half S1 = ss[j*TS + sg*SH + 0];
+                const float S0 = ss[j*TS +         0];
+                const float S1 = ss[j*TS + sg*SH + 0];
 
-                const half M0 = ss[j*TS +         1];
-                const half M1 = ss[j*TS + sg*SH + 1];
+                const float M0 = ss[j*TS +         1];
+                const float M1 = ss[j*TS + sg*SH + 1];
 
                 M = max(M0, M1);
 
-                const half ms0 = exp(M0 - M);
-                const half ms1 = exp(M1 - M);
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
 
@@ -5666,11 +5996,11 @@ kernel void kernel_flash_attn_ext(
                 simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
                 simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
 
-                #pragma unroll(D8)
-                for (short i = 0; i < D8; ++i) {
+                #pragma unroll(DV8)
+                for (short i = 0; i < DV8; ++i) {
                     o8x8_t t;
 
-                    simdgroup_load    (t, so + i*8, D, 0, false);
+                    simdgroup_load    (t, so + i*8, DV, 0, false);
                     simdgroup_multiply(t, ms1, t);
 
                     simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@@ -5681,8 +6011,8 @@ kernel void kernel_flash_attn_ext(
 
     // store result to shared memory (reuse sq)
     if (sgitg == 0) {
-        for (short i = 0; i < D8; ++i) {
-            simdgroup_store(lo[i], so + i*8, D, 0, false);
+        for (short i = 0; i < DV8; ++i) {
+            simdgroup_store(lo[i], so + i*8, DV, 0, false);
         }
     }
 
@@ -5693,8 +6023,8 @@ kernel void kernel_flash_attn_ext(
         for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
             const float S = ss[j*TS + 0];
 
-            for (short i = tiisg; i < D4; i += NW) {
-                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
+            for (short i = tiisg; i < DV4; i += NW) {
+                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
             }
         }
     }
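
In the sequential warp reduction above, each simdgroup contributes a partial state (M, S, O) and the first simdgroup merges them by rescaling both partials to the shared maximum; the vec kernel's parallel reduce later in this patch uses the same arithmetic. A scalar C++ sketch of that merge, with the output reduced to a single element for clarity (illustrative only):

#include <algorithm>
#include <cmath>
#include <cstdio>

struct PartialState {
    float M; // maximum score seen by this simdgroup
    float S; // sum of exp(score - M)
    float O; // accumulated exp(score - M) * value (one output element for brevity)
};

// merge two partial states by rescaling both to the shared maximum
PartialState merge(const PartialState & a, const PartialState & b) {
    const float M   = std::max(a.M, b.M);
    const float ms0 = std::exp(a.M - M);
    const float ms1 = std::exp(b.M - M);
    return { M, a.S*ms0 + b.S*ms1, a.O*ms0 + b.O*ms1 };
}

int main() {
    const PartialState r = merge({2.0f, 1.5f, 3.0f}, {4.0f, 1.2f, 5.0f});
    std::printf("M = %f, S = %f, O/S = %f\n", r.M, r.S, r.O/r.S);
    return 0;
}
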
@@ -5711,80 +6041,94 @@ kernel void kernel_flash_attn_ext(
     float,          simdgroup_float8x8, \
     half,  half4,   simdgroup_half8x8
 
-typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t;
+typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t;
 
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h192")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 #endif
 
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #undef FA_TYPES
 
 template<
-    typename q4_t,    // query types in shared memory
-    typename q4x4_t,
-    typename k4x4_t,  // key types in shared memory
-    typename v4x4_t,  // value types in shared memory
-    typename qk_t,    // Q*K types
-    typename s_t,     // soft-max types
+    typename q4_t,  // query types in shared memory
+    typename k4_t,  // key types in shared memory
+    typename v4_t,  // value types in shared memory
+    typename qk_t,  // Q*K types
+    typename s_t,   // soft-max types
     typename s4_t,
-    typename s4x4_t,
-    typename o4x4_t,  // attention accumulation types
-    typename kd4x4_t, // key type in device memory
+    typename o4_t,  // attention accumulation types
+    typename kd4_t, // key type in device memory
     short nl_k,
-    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
-    typename vd4x4_t, // key type in device memory
+    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
+    typename vd4_t, // key type in device memory
     short nl_v,
-    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
-    short D,         // head size
-    short Q  = 1,    // queries per threadgroup
-    short C  = 32>   // cache items per threadgroup
+    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
+    short DK,       // K head size
+    short DV,       // V head size
+    short NE = 4,   // head elements per thread
+    short Q  = 1,   // queries per threadgroup
+    short C  = 32>  // cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec(
         constant ggml_metal_kargs_flash_attn_ext & args,
         device const char * q,
@@ -5803,29 +6147,28 @@ kernel void kernel_flash_attn_ext_vec(
     const int iq2 = tgpig[1];
     const int iq1 = tgpig[0];
 
-    const short D4  = D/4;
-    const short D16 = D/16;
-    const short NW  = N_SIMDWIDTH;
-    const short NL  = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
-    const short SH  = 2*C;  // shared memory per simdgroup
+    constexpr short DK4 = DK/4;
+    constexpr short DV4 = DV/4;
+    constexpr short NW  = N_SIMDWIDTH;
+    constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
+    constexpr short SH  = 4*C;   // shared memory per simdgroup
 
-    const short T = D + nsg*SH; // shared memory size per query in (half)
+    const short T = DK + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shmem_f16 +                0*D); // holds the query data
-    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shmem_f16 +                0*D); // same as above but in q4_t
-    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 +                0*D); // same as above but in q4x4_t
-    threadgroup s_t    * ss    = (threadgroup s_t    *) (shmem_f16 + sgitg*SH     + Q*D); // scratch buffer for attention
-    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shmem_f16 + sgitg*SH     + Q*D); // same as above but in s4_t
-    threadgroup half   * sm    = (threadgroup half   *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
-    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D      + Q*T); // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                  0*DK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                  0*DK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH       + Q*DK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH       + Q*DK); // same as above but in s4_t
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV       + Q*T);  // scratch buffer for the results
 
-    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    o4x4_t lo[D16/NL];
+    // store the result for all queries in local memory (the O matrix from the paper)
+    o4_t lo[DV4/NL];
 
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
 
-    for (short i = tiisg; i < D4; i += NW) {
+    for (short i = tiisg; i < DK4; i += NW) {
         if (iq1 < args.ne01) {
             sq4[i] = (q4_t) q4[i];
         } else {
@@ -5834,8 +6177,8 @@ kernel void kernel_flash_attn_ext_vec(
     }
 
     // zero out lo
-    for (short i = 0; i < D16/NL; ++i) {
-        lo[i] = (o4x4_t) 0.0f;
+    for (short i = 0; i < DV4/NL; ++i) {
+        lo[i] = (o4_t) 0.0f;
     }
 
     // zero out shared memory SH
@@ -5846,8 +6189,8 @@ kernel void kernel_flash_attn_ext_vec(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        half S = 0.0f;
-        half M = -__FLT16_MAX__/2;
+        float S = 0.0f;
+        float M = -__FLT16_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
@@ -5860,26 +6203,18 @@ kernel void kernel_flash_attn_ext_vec(
         const short ikv2 = iq2/(args.ne02/args.ne_12_2);
         const short ikv3 = iq3/(args.ne03/args.ne_12_3);
 
-        // load the queries from shared memory into local memory
-        q4x4_t mq[D16/NL];
-
-        #pragma unroll(D16/NL)
-        for (short ii = 0; ii < D16; ii += NL) {
-            mq[ii/NL] = sq4x4[ii + tx];
-        }
-
         const bool has_mask = mask != q;
 
         // pointer to the mask
         device const half * pm = (device const half *) (mask + iq1*args.nb31);
 
-        half slope = 1.0f;
+        float slope = 1.0f;
 
         // ALiBi
         if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
             const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
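
A standalone C++ sketch of the ALiBi slope selection above. m0, m1 and n_head_log2 are assumed to be precomputed from max_bias and the head count following ggml's usual host-side convention (an assumption here, shown only for illustration):

#include <cmath>
#include <cstdio>

int main() {
    const int   n_head   = 12;   // assumed head count
    const float max_bias = 8.0f; // assumed args.max_bias

    const int   n_head_log2 = 1 << (int) std::floor(std::log2((float) n_head));
    const float m0 = std::pow(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (int h = 0; h < n_head; ++h) {
        const float base  = h < n_head_log2 ? m0 : m1;
        const int   exph  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
        const float slope = std::pow(base, (float) exph); // the mask is added as slope*mask
        std::printf("head %2d: slope = %g\n", h, slope);
    }
    return 0;
}
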
@@ -5899,43 +6234,56 @@ kernel void kernel_flash_attn_ext_vec(
 
             // Q*K^T
             {
-                // each simdgroup processes 1 query and 4 (NW/NL) keys
-                for (short cc = 0; cc < C/4; ++cc) {
-                    qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
+                // each simdgroup processes 1 query and NE (NW/NL) head elements
+                for (short cc = 0; cc < C/NE; ++cc) {
+                    qk_t mqk = 0.0f;
 
-                    device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                    device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
 
-                    #pragma unroll(D16/NL)
-                    for (short ii = 0; ii < D16; ii += NL) {
+                    #pragma unroll(DK4/NL)
+                    for (short ii = 0; ii < DK4; ii += NL) {
                         const short i = ii + tx;
 
-                        k4x4_t mk;
-                        deq_k(pk + i/nl_k, i%nl_k, mk);
+                        k4_t mk;
+                        deq_k_t4(pk + i/nl_k, i%nl_k, mk);
 
                         // note: this is less precise than the version below
-                        //mqka[0] += dot(mq[ii/NL][0], mk[0]);
-                        //mqka[1] += dot(mq[ii/NL][1], mk[1]);
-                        //mqka[2] += dot(mq[ii/NL][2], mk[2]);
-                        //mqka[3] += dot(mq[ii/NL][3], mk[3]);
+                        //mqka[0] += dot(mq[0], mk[0]);
+                        //mqka[1] += dot(mq[1], mk[1]);
+                        //mqka[2] += dot(mq[2], mk[2]);
+                        //mqka[3] += dot(mq[3], mk[3]);
 
-                        mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
-                        mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
-                        mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
-                        mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
+                        //q4x4_t mq = sq4x4[i];
+                        //mqka[0] += dot((float4) mq[0], (float4) mk[0]);
+                        //mqka[1] += dot((float4) mq[1], (float4) mk[1]);
+                        //mqka[2] += dot((float4) mq[2], (float4) mk[2]);
+                        //mqka[3] += dot((float4) mq[3], (float4) mk[3]);
+
+                        mqk += dot((float4) mk, (float4) sq4[i]);
                     }
 
-                    qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
+                    static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
 
-                    // simdgroup reduce
+                    // simdgroup reduce (NE = 4)
                     // [ 0 ..  7] -> [ 0]
                     // [ 8 .. 15] -> [ 8]
                     // [16 .. 23] -> [16]
                     // [24 .. 31] -> [24]
-                  //mqk += simd_shuffle_down(mqk, 16);
-                  //mqk += simd_shuffle_down(mqk,  8);
-                    mqk += simd_shuffle_down(mqk,  4);
-                    mqk += simd_shuffle_down(mqk,  2);
-                    mqk += simd_shuffle_down(mqk,  1);
+                    if (NE <= 1) {
+                        mqk += simd_shuffle_down(mqk, 16);
+                    }
+                    if (NE <= 2) {
+                        mqk += simd_shuffle_down(mqk,  8);
+                    }
+                    if (NE <= 4) {
+                        mqk += simd_shuffle_down(mqk,  4);
+                    }
+                    if (NE <= 8) {
+                        mqk += simd_shuffle_down(mqk,  2);
+                    }
+                    if (NE <= 16) {
+                        mqk += simd_shuffle_down(mqk,  1);
+                    }
 
                     // mqk = mqk*scale + mask*slope
                     if (tx == 0) {
@@ -5945,9 +6293,9 @@ kernel void kernel_flash_attn_ext_vec(
                             mqk = args.logit_softcap*precise::tanh(mqk);
                         }
 
-                        mqk += sm[4*cc + ty]*slope;
+                        mqk += sm[NE*cc + ty]*slope;
 
-                        ss[4*cc + ty] = mqk;
+                        ss[NE*cc + ty] = mqk;
                     }
                 }
             }
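
The NE-gated shuffle reduction above folds each group of NL lanes into its first lane. A host-side C++ emulation of the NE = 4 case (offsets 4, 2, 1), with simd_shuffle_down modeled as reading the value held by the lane off positions higher; illustrative only, not Metal:

#include <cstdio>

int main() {
    const int NW = 32; // simdgroup width
    const int NE = 4;  // head elements per thread -> NL = NW/NE = 8 lanes per key

    float lane[NW];
    for (int t = 0; t < NW; ++t) lane[t] = 1.0f; // every lane holds a partial dot product

    // model simd_shuffle_down(x, off) as "read x from the lane off positions higher"
    auto shuffle_add = [&](int off) {
        float next[32];
        for (int t = 0; t < NW; ++t) next[t] = lane[t] + (t + off < NW ? lane[t + off] : 0.0f);
        for (int t = 0; t < NW; ++t) lane[t] = next[t];
    };

    if (NE <= 1)  shuffle_add(16);
    if (NE <= 2)  shuffle_add(8);
    if (NE <= 4)  shuffle_add(4);
    if (NE <= 8)  shuffle_add(2);
    if (NE <= 16) shuffle_add(1);

    // lanes 0, 8, 16 and 24 now hold the sums of their 8-lane groups
    std::printf("lane 0 = %g, lane 8 = %g\n", lane[0], lane[8]);
    return 0;
}
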
@@ -5956,13 +6304,13 @@ kernel void kernel_flash_attn_ext_vec(
 
             // online softmax
             {
-                const half m = M;
-                const half s = ss[tiisg];
+                const float m = M;
+                const float s = ss[tiisg];
 
                 M = simd_max(max(M, s));
 
-                const half ms = exp(m - M);
-                const half vs = exp(s - M);
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
 
                 S = S*ms + simd_sum(vs);
 
@@ -5970,8 +6318,8 @@ kernel void kernel_flash_attn_ext_vec(
                 ss[tiisg] = vs;
 
                 // O = diag(ms)*O
-                #pragma unroll(D16/NL)
-                for (short ii = 0; ii < D16; ii += NL) {
+                #pragma unroll(DV4/NL)
+                for (short ii = 0; ii < DV4; ii += NL) {
                     lo[ii/NL] *= ms;
                 }
             }
@@ -5980,19 +6328,20 @@ kernel void kernel_flash_attn_ext_vec(
 
             // O = O + (Q*K^T)*V
             {
-                for (short cc = 0; cc < C/4; ++cc) {
-                    device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                //#pragma unroll(C/NE)
+                for (short cc = 0; cc < C/NE; ++cc) {
+                    device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
 
-                    const s4x4_t ms(ss[4*cc + ty]);
+                    const s4_t ms(ss[NE*cc + ty]);
 
-                    #pragma unroll(D16/NL)
-                    for (short ii = 0; ii < D16; ii += NL) {
+                    #pragma unroll(DV4/NL)
+                    for (short ii = 0; ii < DV4; ii += NL) {
                         const short i = ii + tx;
 
-                        v4x4_t mv;
-                        deq_v(pv4 + i/nl_v, i%nl_v, mv);
+                        v4_t mv;
+                        deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
 
-                        lo[ii/NL] += mv*ms;
+                        lo[ii/NL] += o4_t(float4(mv)*float4(ms));
                     }
                 }
             }
@@ -6005,7 +6354,7 @@ kernel void kernel_flash_attn_ext_vec(
         }
     }
 
-    // simdgroup reduce
+    // simdgroup reduce (NE = 4)
     // [ 0,  8, 16, 24] -> [ 0]
     // [ 1,  9, 17, 25] -> [ 1]
     // [ 2, 10, 18, 26] -> [ 2]
@@ -6014,37 +6363,48 @@ kernel void kernel_flash_attn_ext_vec(
     // [ 5, 13, 21, 29] -> [ 5]
     // [ 6, 14, 22, 30] -> [ 6]
     // [ 7, 15, 23, 31] -> [ 7]
-    for (short ii = 0; ii < D16; ii += NL) {
-        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
-        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
+    for (short ii = 0; ii < DV4; ii += NL) {
+        if (NE > 1) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
+        }
 
-        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
-        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
+        if (NE > 2) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
+        }
 
-        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
-        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
+        if (NE > 4) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
+        }
 
-        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
-        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
+        if (NE > 8) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
+        }
+
+        if (NE > 16) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
+        }
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // store results to shared memory
-    for (short i = tiisg; i < D16; i += NL) {
-        sr4x4[i] = lo[i/NL];
+    for (short i = tiisg; i < DV4; i += NL) {
+        sr4[i] = lo[i/NL];
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -6052,18 +6412,18 @@ kernel void kernel_flash_attn_ext_vec(
     // parallel reduce
     for (short r = nsg/2; r > 0; r >>= 1) {
         if (sgitg < r) {
-            const half S0 = ss[       0];
-            const half S1 = ss[r*SH + 0];
+            const float S0 = ss[           0];
+            const float S1 = ss[r*(SH/2) + 0];
 
-            const half M0 = ss[       1];
-            const half M1 = ss[r*SH + 1];
+            const float M0 = ss[           1];
+            const float M1 = ss[r*(SH/2) + 1];
 
-            const half M = max(M0, M1);
+            const float M = max(M0, M1);
 
-            const half ms0 = exp(M0 - M);
-            const half ms1 = exp(M1 - M);
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
 
-            const half S = S0*ms0 + S1*ms1;
+            const float S = S0*ms0 + S1*ms1;
 
             if (tiisg == 0) {
                 ss[0] = S;
@@ -6071,22 +6431,22 @@ kernel void kernel_flash_attn_ext_vec(
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (short i = tiisg; i < D16; i += NW) {
-                sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
+            for (short i = tiisg; i < DV4; i += NW) {
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
             }
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    device float4x4 * dst44 = (device float4x4 *) dst;
+    device float4 * dst4 = (device float4 *) dst;
 
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
         const float S = ss[0];
 
-        for (short i = tiisg; i < D16; i += NW) {
-            dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
+        for (short i = tiisg; i < DV4; i += NW) {
+            dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
         }
     }
 }
@@ -6095,34 +6455,54 @@ kernel void kernel_flash_attn_ext_vec(
 //       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
 #define FA_TYPES \
-           half4,  half4x4, \
-                   half4x4, \
-                   half4x4, \
-    float,                  \
-    half,  half4,  half4x4, \
-                   half4x4
+           half4,  \
+           half4,  \
+           half4,  \
+    float,         \
+    float, float4, \
+           half4
 
-typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t;
+typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 
 #undef FA_TYPES
 
@@ -6497,7 +6877,7 @@ kernel void kernel_cpy_f32_iq4_nl(
         float amax = 0.0f; // absolute max
         float max  = 0.0f;
 
-        for (int j = 0; j < QK4_0; j++) {
+        for (int j = 0; j < QK4_NL; j++) {
             const float v = src[j];
             if (amax < fabs(v)) {
                 amax = fabs(v);
@@ -6605,7 +6985,7 @@ kernel void kernel_concat(
     }
 }
 
-template
+template
 void kernel_mul_mv_q2_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6621,7 +7001,7 @@ void kernel_mul_mv_q2_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6633,20 +7013,19 @@ void kernel_mul_mv_q2_K_f32_impl(
     device const float      * y = (device const float      *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
-    const int is = (8*ir)/16;// 0 or 1
+    const short ix = tiisg/8;  // 0...3
+    const short it = tiisg%8;  // 0...7
+    const short iq = it/4;     // 0 or 1
+    const short ir = it%4;     // 0...3
+    const short is = (8*ir)/16;// 0 or 1
 
     device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
 
     for (int ib = ix; ib < nb; ib += 4) {
-
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
+        for (short i = 0; i < 8; ++i) {
             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
             yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
             yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
@@ -6657,7 +7036,7 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half     * dh = &x[ib].d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {
@@ -6688,10 +7067,10 @@ void kernel_mul_mv_q2_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
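
With the change above, output rows are tiled across simdgroups: threadgroup r0 holds nsg simdgroups and each simdgroup accumulates nr0 consecutive rows, so first_row = (r0*nsg + sgitg)*nr0. A small C++ sketch of that mapping; the values of nsg and nr0 below are assumptions for illustration only:

#include <cstdio>

int main() {
    const int nsg = 2; // simdgroups per threadgroup (assumed)
    const int nr0 = 4; // output rows per simdgroup  (assumed)

    for (int r0 = 0; r0 < 2; ++r0) {              // threadgroup index along dim 0
        for (int sgitg = 0; sgitg < nsg; ++sgitg) {
            const int first_row = (r0*nsg + sgitg)*nr0;
            std::printf("tg %d, sg %d -> rows [%d, %d)\n", r0, sgitg, first_row, first_row + nr0);
        }
    }
    return 0;
}
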
@@ -6706,10 +7085,10 @@ kernel void kernel_mul_mv_q2_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template
+template
 void kernel_mul_mv_q3_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6726,7 +7105,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6742,13 +7121,12 @@ void kernel_mul_mv_q3_K_f32_impl(
     //const uint16_t kmask1 = 0x3030;
     //const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int ip  = tid/4;          // 0 or 1
-    const int il  = 2*((tid%4)/2);  // 0 or 2
-    const int ir  = tid%2;
-    const int n   = 8;
-    const int l0  = n*ir;
+    const short tid = tiisg/4;
+    const short ix  = tiisg%4;
+    const short ip  = tid/4;          // 0 or 1
+    const short il  = 2*((tid%4)/2);  // 0 or 2
+    const short ir  = tid%2;
+    const short l0  = 8*ir;
 
     // One would think that the Metal compiler would figure out that ip and il can only have
     // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
@@ -6773,8 +7151,8 @@ void kernel_mul_mv_q3_K_f32_impl(
     const uint16_t s_shift1 = 4*ip;
     const uint16_t s_shift2 = s_shift1 + il;
 
-    const int q_offset = 32*ip + l0;
-    const int y_offset = 128*ip + 32*il + l0;
+    const short q_offset = 32*ip + l0;
+    const short y_offset = 128*ip + 32*il + l0;
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
@@ -6782,10 +7160,11 @@ void kernel_mul_mv_q3_K_f32_impl(
     thread uint16_t * scales16 = (thread uint16_t *)&scales32;
     thread const int8_t * scales = (thread const int8_t *)&scales32;
 
-    float sumf1[2] = {0.f};
-    float sumf2[2] = {0.f};
+    float sumf1[nr0] = {0.f};
+    float sumf2[nr0] = {0.f};
+
     for (int i = ix; i < nb; i += 4) {
-        for (int l = 0; l < 8; ++l) {
+        for (short l = 0; l < 8; ++l) {
             yl[l+ 0] = y1[l+ 0];
             yl[l+ 8] = y1[l+16];
             yl[l+16] = y1[l+32];
@@ -6797,7 +7176,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         device const uint16_t * a = (device const uint16_t *)(x[i].scales);
         device const half * dh = &x[i].d;
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             const float d_all = (float)dh[0];
 
             scales16[0] = a[4];
@@ -6808,7 +7187,7 @@ void kernel_mul_mv_q3_K_f32_impl(
             scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
 
             float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
-            for (int l = 0; l < n; l += 2) {
+            for (short l = 0; l < 8; l += 2) {
                 const int32_t qs = q[l/2];
                 s1 += yl[l+0] * (qs & qm[il/2][0]);
                 s2 += yl[l+1] * (qs & qm[il/2][1]);
@@ -6823,7 +7202,7 @@ void kernel_mul_mv_q3_K_f32_impl(
             sumf2[row] += d2 * (scales[2] - 32);
 
             s1 = s2 = s3 = s4 = s5 = s6 = 0;
-            for (int l = 0; l < n; l += 2) {
+            for (short l = 0; l < 8; l += 2) {
                 const int32_t qs = q[l/2+8];
                 s1 += yl[l+8] * (qs & qm[il/2][0]);
                 s2 += yl[l+9] * (qs & qm[il/2][1]);
@@ -6846,7 +7225,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         y1 += 4 * QK_K;
     }
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
         sumf1[row] = simd_sum(sumf);
     }
@@ -6854,7 +7233,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
     if (tiisg == 0) {
-        for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+        for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
             dst_f32[first_row + row] = sumf1[row];
         }
     }
@@ -6870,10 +7249,10 @@ kernel void kernel_mul_mv_q3_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template
+template
 void kernel_mul_mv_q4_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6883,22 +7262,22 @@ void kernel_mul_mv_q4_K_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
-
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
+    const short ix = tiisg/8;  // 0...3
+    const short it = tiisg%8;  // 0...7
+    const short iq = it/4;     // 0 or 1
+    const short ir = it%4;     // 0...3
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int first_row = r0 * N_DST;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6911,7 +7290,8 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     float yl[16];
     float yh[16];
-    float sumf[N_DST]={0.f}, all_sum;
+
+    float sumf[nr0]={0.f};
 
     device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
 
@@ -6920,7 +7300,8 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     for (int ib = ix; ib < nb; ib += 4) {
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
+
+        for (short i = 0; i < 8; ++i) {
             yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
             yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
             yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
@@ -6931,7 +7312,7 @@ void kernel_mul_mv_q4_K_f32_impl(
         device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half     * dh = &x[ib].d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             sc16[0] = sc[0] & kmask1;
             sc16[1] = sc[2] & kmask1;
             sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
@@ -6941,19 +7322,21 @@ void kernel_mul_mv_q4_K_f32_impl(
 
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
-            for (int i = 0; i < 8; i += 2) {
-                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
-                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
-                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
-                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
-                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
-                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
-                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
-                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+
+            for (short i = 0; i < 4; ++i) {
+                acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
+                acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
+                acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
+                acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
+                acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
+                acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
+                acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
+                acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
             }
 
             float dall = dh[0];
             float dmin = dh[1];
+
             sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
                                  (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
                                  (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
@@ -6970,10 +7353,10 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
@@ -6988,10 +7371,10 @@ kernel void kernel_mul_mv_q4_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template
+template
 void kernel_mul_mv_q5_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -7008,7 +7391,7 @@ void kernel_mul_mv_q5_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7019,7 +7402,7 @@ void kernel_mul_mv_q5_K_f32_impl(
     device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
     device const float     * yy = (device const float      *) (src1 + offset1);
 
-    float sumf[2]={0.f};
+    float sumf[nr0]={0.f};
 
     float yl[16], yh[16];
 
@@ -7027,15 +7410,14 @@ void kernel_mul_mv_q5_K_f32_impl(
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int iq  = tid/4;
-    const int ir  = tid%4;
-    const int n   = 8;
+    const short tid = tiisg/4;
+    const short ix  = tiisg%4;
+    const short iq  = tid/4;
+    const short ir  = tid%4;
 
-    const int l0 = n*ir;
-    const int q_offset = 32*iq + l0;
-    const int y_offset = 64*iq + l0;
+    const short l0 = 8*ir;
+    const short q_offset = 32*iq + l0;
+    const short y_offset = 64*iq + l0;
 
     const uint8_t hm1 = 1u << (2*iq);
     const uint8_t hm2 = hm1 << 1;
@@ -7055,14 +7437,14 @@ void kernel_mul_mv_q5_K_f32_impl(
 
         device const float * y2 = y1 + 128;
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < 8; ++l) {
+        for (short l = 0; l < 8; ++l) {
             yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
             yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
             yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
             yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
         }
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             device const uint8_t * q2 = q1 + 64;
 
             sc16[0] = a[0] & kmask1;
@@ -7072,7 +7454,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
             float4 acc1 = {0.f};
             float4 acc2 = {0.f};
-            for (int l = 0; l < n; ++l) {
+            for (short l = 0; l < 8; ++l) {
                 uint8_t h = qh[l];
                 acc1[0] += yl[l+0] * (q1[l] & 0x0F);
                 acc1[1] += yl[l+8] * (q1[l] & 0xF0);
@@ -7102,7 +7484,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = tot;
@@ -7120,10 +7502,10 @@ kernel void kernel_mul_mv_q5_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template 
+template
 void kernel_mul_mv_q6_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -7145,62 +7527,77 @@ void kernel_mul_mv_q6_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int row = 2*r0 + sgitg;
-
-    if (row >= args.ne0) {
-        return;
-    }
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
-    const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    const uint64_t offset1 =  r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
     device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
     device const float     * yy = (device const float      *) (src1 + offset1);
 
-    float sumf = 0;
+    float sumf[nr0] = { 0.f };
 
-    const int tid  = tiisg/2;
-    const int ix   = tiisg%2;
-    const int ip   = tid/8;         // 0 or 1
-    const int il   = tid%8;
-    const int n    = 4;
-    const int l0   = n*il;
-    const int is   = 8*ip + l0/16;
+    float yl[16];
 
-    const int y_offset = 128*ip + l0;
-    const int q_offset_l = 64*ip + l0;
-    const int q_offset_h = 32*ip + l0;
+    const short tid = tiisg/2;
+    const short ix  = tiisg%2;
+    const short ip  = tid/8;         // 0 or 1
+    const short il  = tid%8;
+    const short l0  = 4*il;
+    const short is  = 8*ip + l0/16;
+
+    const short y_offset   = 128*ip + l0;
+    const short q_offset_l =  64*ip + l0;
+    const short q_offset_h =  32*ip + l0;
 
     for (int i = ix; i < nb; i += 2) {
         device const uint8_t * q1 = x[i].ql + q_offset_l;
         device const uint8_t * q2 = q1 + 32;
         device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
+        device const half    * dh = &x[i].d;
 
         device const float * y = yy + i * QK_K + y_offset;
 
-        const float dall = x[i].d;
-
-        float4 sums = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        for (short l = 0; l < 4; ++l) {
+            yl[4*l + 0] = y[l +  0];
+            yl[4*l + 1] = y[l + 32];
+            yl[4*l + 2] = y[l + 64];
+            yl[4*l + 3] = y[l + 96];
         }
 
-        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+        for (short row = 0; row < nr0; ++row) {
+            const float dall = dh[0];
 
+            float4 sums = {0.f, 0.f, 0.f, 0.f};
+
+            for (short l = 0; l < 4; ++l) {
+                sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+                sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+                sums[2] += yl[4*l + 2] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+                sums[3] += yl[4*l + 3] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+            }
+
+            sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+            q1 += args.nb01;
+            q2 += args.nb01;
+            qh += args.nb01;
+            sc += args.nb01;
+            dh += args.nb01/2;
+        }
     }
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    const float tot = simd_sum(sumf);
-    if (tiisg == 0) {
-        dst_f32[row] = tot;
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = sum_all;
+        }
     }
 }
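The reworked q6_K path above follows the pattern used by the other quantized mul-mv kernels in this patch: each simdgroup now owns nr0 consecutive output rows, accumulating them in sumf[nr0] while stepping the quantized-data pointers by the row stride args.nb01 (and the half-precision scale pointer dh by args.nb01/2, since it is indexed in 2-byte elements), then reducing each accumulator with simd_sum. The C++ sketch below replays the resulting row-to-simdgroup mapping on the CPU; the ceil-division for the threadgroup count is an assumption inferred from the first_row + row < args.ne0 bounds check (the host-side dispatch code is not part of this hunk), and N_R0_Q6_K = 1 / N_SG_Q6_K = 2 are the values from the ggml-metal-impl.h hunk later in this patch.

    // Sketch of the row mapping used by the reworked mul-mv kernels. Assumption:
    // the host launches ceil(ne01 / (nr0*nsg)) threadgroups along x, as implied by
    // the "first_row + row < args.ne0" bounds checks. Verifies that every output
    // row is visited by exactly one simdgroup.
    #include <cstdio>
    #include <vector>

    int main() {
        const int nr0 = 1, nsg = 2;   // N_R0_Q6_K / N_SG_Q6_K from ggml-metal-impl.h
        const int ne01 = 37;          // arbitrary row count, not a multiple of nr0*nsg
        const int per_tg = nr0 * nsg;
        const int n_tg   = (ne01 + per_tg - 1) / per_tg;

        std::vector<int> hits(ne01, 0);
        for (int r0 = 0; r0 < n_tg; ++r0) {             // threadgroup index (tgpig.x)
            for (int sgitg = 0; sgitg < nsg; ++sgitg) { // simdgroup within the threadgroup
                const int first_row = (r0*nsg + sgitg)*nr0;
                for (int row = 0; row < nr0 && first_row + row < ne01; ++row) {
                    hits[first_row + row]++;            // this simdgroup accumulates sumf[row]
                }
            }
        }
        for (int i = 0; i < ne01; ++i) {
            if (hits[i] != 1) { printf("row %d covered %d times\n", i, hits[i]); return 1; }
        }
        printf("all %d rows covered exactly once\n", ne01);
        return 0;
    }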
 
@@ -7214,12 +7611,12 @@ kernel void kernel_mul_mv_q6_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
 
-template
+template
 void kernel_mul_mv_iq2_xxs_f32_impl(
         args_t args,
         device const char * src0,
@@ -7235,7 +7632,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7247,7 +7644,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     device const float         * y = (device const float         *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -7268,8 +7665,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -7280,18 +7676,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
         device const uint16_t * q2 = xr->qs + 4 * ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             device const uint8_t * aux8 = (device const uint8_t *)q2;
             const uint32_t aux32 = q2[2] | (q2[3] << 16);
             const float d = db * (0.5f + (aux32 >> 28));
 
             float sum = 0;
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
@@ -7306,10 +7701,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
         }
     }
 }
@@ -7324,10 +7719,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template
+template
 void kernel_mul_mv_iq2_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -7343,7 +7738,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7355,7 +7750,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     device const float        * y = (device const float        *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -7376,8 +7771,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -7389,8 +7783,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
         device const uint8_t  * sc = xr->scales + ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const uint8_t ls1 = sc[0] & 0xf;
             const uint8_t ls2 = sc[0] >>  4;
@@ -7398,17 +7791,17 @@ void kernel_mul_mv_iq2_xs_f32_impl(
             const float d2 = db * (0.5f + ls2);
 
             float sum1 = 0, sum2 = 0;
-            for (int l = 0; l < 2; ++l) {
+            for (short l = 0; l < 2; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                 const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
-            for (int l = 2; l < 4; ++l) {
+            for (short l = 2; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                 const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
@@ -7424,10 +7817,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
         }
     }
 }
@@ -7443,10 +7836,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template 
+template
 void kernel_mul_mv_iq3_xxs_f32_impl(
         args_t args,
         device const char * src0,
@@ -7462,7 +7855,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7474,7 +7867,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     device const float         * y = (device const float         *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -7495,7 +7888,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -7507,17 +7900,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
         device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const uint32_t aux32 = gas[0] | (gas[1] << 16);
             const float d = db * (0.5f + (aux32 >> 28));
 
             float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 4; ++j) {
+                for (short j = 0; j < 4; ++j) {
                     sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                     sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                 }
@@ -7534,10 +7927,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.5f;
+            dst_f32[first_row + row] = sum_all * 0.5f;
         }
     }
 }
@@ -7553,10 +7946,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template
+template
 void kernel_mul_mv_iq3_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -7572,7 +7965,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7584,7 +7977,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -7601,8 +7994,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -7616,18 +8008,17 @@ void kernel_mul_mv_iq3_s_f32_impl(
         device const uint8_t * signs = xr->signs + 4 * ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
 
             float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
                 const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
-                for (int j = 0; j < 4; ++j) {
+                for (short j = 0; j < 4; ++j) {
                     sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
                     sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                 }
@@ -7646,10 +8037,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
@@ -7665,10 +8056,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template 
+template
 void kernel_mul_mv_iq2_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -7684,7 +8075,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7696,7 +8087,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -7708,13 +8099,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
     //    threadgroup_barrier(mem_flags::mem_threadgroup);
     //}
 
-    const int ix = tiisg;
+    const short ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -7728,19 +8118,18 @@ void kernel_mul_mv_iq2_s_f32_impl(
         device const uint8_t * signs = qs + QK_K/8;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const float d1 = db * (0.5f + (sc[0] & 0xf));
             const float d2 = db * (0.5f + (sc[0] >>  4));
 
             float2 sum = {0};
-            for (int l = 0; l < 2; ++l) {
+            for (short l = 0; l < 2; ++l) {
                 //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                 //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                 constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                 constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
                     sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
                 }
@@ -7759,10 +8148,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
         }
     }
 }
@@ -7778,10 +8167,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template
+template
 void kernel_mul_mv_iq1_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -7797,7 +8186,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7809,18 +8198,17 @@ void kernel_mul_mv_iq1_s_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
-    const int ix = tiisg;
+    const short ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
         float sumy = 0;
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
             sumy += yl[i];
         }
@@ -7833,15 +8221,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
         device const uint16_t * qh = xr->qh + ib;
         device const half     * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
             constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
             constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
 
             float sum = 0;
-            for (int j = 0; j < 4; ++j) {
+            for (short j = 0; j < 4; ++j) {
                 sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                      + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
                      + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -7859,15 +8246,28 @@ void kernel_mul_mv_iq1_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-template 
+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template
 void kernel_mul_mv_iq1_m_f32_impl(
         args_t args,
         device const char * src0,
@@ -7879,11 +8279,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
         ushort sgitg) {
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7895,20 +8296,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
-    const int ix = tiisg;
+    const short ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
 
     iq1m_scale_t scale;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
         float4 sumy = {0.f};
-        for (int i = 0; i < 8; ++i) {
+        for (short i = 0; i < 8; ++i) {
             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
             yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
             yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
@@ -7923,7 +8323,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
         device const uint8_t  * qh = xr->qh + 2 * ib;
         device const uint16_t * sc = (device const uint16_t *)xr->scales;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
 
             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -7932,7 +8332,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
 
             float2 sum = {0.f};
-            for (int j = 0; j < 4; ++j) {
+            for (short j = 0; j < 4; ++j) {
                 sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                         + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                 sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -7954,15 +8354,28 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-template
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template
 void kernel_mul_mv_iq4_nl_f32_impl(
         args_t args,
         device const char * src0,
@@ -7975,10 +8388,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
     const int nb = args.ne00/QK4_NL;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * 2 + sgitg) * 2;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7989,14 +8404,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
-    const int ix = tiisg/2;  // 0...15
-    const int it = tiisg%2;  // 0 or 1
+    const short ix = tiisg/2;  // 0...15
+    const short it = tiisg%2;  // 0 or 1
 
     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[2]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     device const float * yb = y + ix * QK4_NL + it * 8;
 
@@ -8006,12 +8421,13 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     float4 qf1, qf2;
 
     for (int ib = ix; ib < nb; ib += 16) {
-
         device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
-
-        for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
 
+        for (short row = 0; row < nr0; row++) {
             device const block_iq4_nl & xb = x[row*nb + ib];
             device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
 
@@ -8036,7 +8452,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
             acc1 += acc2;
 
             sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
         }
 
         yb += 16 * QK4_NL;
@@ -8044,15 +8459,29 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
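kernel_mul_mv_iq4_nl_f32_impl keeps its existing trick of staging the 16-entry non-linear codebook into threadgroup memory (shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]) and dequantizing every 4-bit index through that table before the dot product; the refactor only changes how rows are distributed across simdgroups. The C++ sketch below mirrors just the nibble-unpack/lookup/scale step; the codebook values and packed bytes are illustrative placeholders, not the real kvalues_iq4nl table.

    // CPU analog of IQ4_NL-style lookup-table dequantization: each 4-bit index
    // selects an entry from a 16-value codebook, then the per-block scale applies.
    // Codebook and data are placeholders; only the access pattern mirrors the kernel.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const float codebook[16] = {   // placeholder, non-uniformly spaced values
            -1.00f, -0.82f, -0.65f, -0.51f, -0.39f, -0.28f, -0.17f, -0.08f,
             0.01f,  0.10f,  0.20f,  0.30f,  0.42f,  0.54f,  0.70f,  0.89f,
        };
        const uint8_t qs[4] = { 0x3A, 0xF0, 0x57, 0x81 }; // packed 4-bit indices
        const float   d     = 0.125f;                     // per-block scale (xb.d in the kernel)

        for (int i = 0; i < 4; ++i) {
            const float lo = d * codebook[qs[i] & 0x0F];  // low nibble
            const float hi = d * codebook[qs[i] >> 4];    // high nibble
            printf("byte %d -> %+.4f %+.4f\n", i, lo, hi);
        }
        return 0;
    }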
 
-template
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template
 void kernel_mul_mv_iq4_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -8068,7 +8497,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * 2 + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -8079,16 +8508,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
-    const int ix = tiisg/16;  // 0 or 1
-    const int it = tiisg%16;  // 0...15
-    const int ib = it/2;
-    const int il = it%2;
+    const short ix = tiisg/16;  // 0 or 1
+    const short it = tiisg%16;  // 0...15
+    const short ib = it/2;
+    const short il = it%2;
 
     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[2]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
 
@@ -8099,9 +8528,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     for (int ibl = ix; ibl < nb; ibl += 2) {
         device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             device const block_iq4_xs & xb = x[row*nb + ibl];
             device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
 
@@ -8125,7 +8557,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
             const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
             sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
         }
 
         yb += 2 * QK_K;
@@ -8133,54 +8564,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-[[host_name("kernel_mul_mv_iq1_s_f32")]]
-kernel void kernel_mul_mv_iq1_s_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq1_m_f32")]]
-kernel void kernel_mul_mv_iq1_m_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq4_nl_f32")]]
-kernel void kernel_mul_mv_iq4_nl_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        threadgroup  char * shmem [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
-}
-
 [[host_name("kernel_mul_mv_iq4_xs_f32")]]
 kernel void kernel_mul_mv_iq4_xs_f32(
         constant ggml_metal_kargs_mul_mv & args,
@@ -8192,7 +8583,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 template
@@ -8200,28 +8591,21 @@ kernel void kernel_get_rows_q(
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
 
     const int64_t i02 = i11;
 
-    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+    for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
         float4x4 temp;
-        dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+        dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
     }
 }
 
@@ -8230,27 +8614,20 @@ kernel void kernel_get_rows_f(
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
 
     const int64_t i02 = i11;
 
-    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((      device float *) ((      device char *)  dst + i11*nb2  + i10*nb1))[ind] =
-        ((const device T     *) ((const device char *) src0 + i02*nb02 +  r*nb01))[ind];
+    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
+        ((      device float *) ((      device char *)  dst + i11*args.nb2  + i10*args.nb1))[ind] =
+        ((const device T     *) ((const device char *) src0 + i02*args.nb02 +  r*args.nb01))[ind];
     }
 }
 
@@ -8258,27 +8635,20 @@ kernel void kernel_get_rows_i32(
         device const  void * src0,
         device const  void * src1,
         device     int32_t * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
 
     const int64_t i02 = i11;
 
-    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((      device int32_t *) ((      device char *) dst  + i11*nb2 + i10*nb1))[ind] =
-        ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
+        ((      device int32_t *) ((      device char *) dst  + i11*args.nb2 + i10*args.nb1))[ind] =
+        ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
     }
 }
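The three get_rows kernels above now take a single constant ggml_metal_kargs_get_rows struct (defined in the ggml-metal-impl.h hunk later in this patch) instead of eight loose constant scalars; the operation itself is unchanged: read a row index from src1, then copy (or dequantize) that row of src0 into dst using byte strides. A host-side C++ sketch of the float variant, with made-up data, looks like this:

    // Gather rows of src0 selected by the int32 indices in src1, using the same
    // byte strides the kernel reads from ggml_metal_kargs_get_rows. Data is made up.
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
        const int64_t ne00 = 4;                          // row length
        float   src0[12] = { 0,1,2,3,  10,11,12,13,  20,21,22,23 };
        int32_t src1[2]  = { 2, 0 };                     // rows to gather
        float   dst[8]   = { 0 };

        const uint64_t nb01 = ne00*sizeof(float);        // src0 row stride in bytes (args.nb01)
        const uint64_t nb10 = sizeof(int32_t);           // index stride in bytes    (args.nb10)
        const uint64_t nb1  = ne00*sizeof(float);        // dst row stride in bytes  (args.nb1)

        for (int64_t i10 = 0; i10 < 2; ++i10) {          // tgpig.x in the kernel
            const int32_t r = *(const int32_t *)((const char *)src1 + i10*nb10);
            memcpy((char *)dst + i10*nb1, (const char *)src0 + r*nb01, (size_t)nb01);
        }
        for (int i = 0; i < 8; ++i) printf("%g ", dst[i]); // 20 21 22 23 0 1 2 3
        printf("\n");
        return 0;
    }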
 
@@ -8857,121 +9227,103 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
 #endif
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
 
 kernel void kernel_pool_2d_max_f32(
         device  const float * src0,
         device        float * dst,
-        constant    int32_t & k0,
-        constant    int32_t & k1,
-        constant    int32_t & s0,
-        constant    int32_t & s1,
-        constant    int32_t & p0,
-        constant    int32_t & p1,
-        constant    int64_t & IH,
-        constant    int64_t & IW,
-        constant    int64_t & OH,
-        constant    int64_t & OW,
-        constant    int64_t & parallel_elements,
+        constant    ggml_metal_kargs_pool_2d & args,
         uint        gid[[thread_position_in_grid]]) {
 
-    if (gid >= parallel_elements) {
+    if (gid >= args.parallel_elements) {
         return;
     }
 
     const int idx = gid;
-    const int I_HW = IH * IW;
-    const int O_HW = OH * OW;
+    const int I_HW = args.IH * args.IW;
+    const int O_HW = args.OH * args.OW;
     const int nc = idx / O_HW;
-    const int cur_oh = idx % O_HW / OW;
-    const int cur_ow = idx % O_HW % OW;
+    const int cur_oh = idx % O_HW / args.OW;
+    const int cur_ow = idx % O_HW % args.OW;
 
     device const float * i_ptr = src0 + nc * I_HW;
     device       float * o_ptr = dst  + nc * O_HW;
 
-    const int start_h = cur_oh * s1 - p1;
+    const int start_h = cur_oh * args.s1 - args.p1;
     const int bh = MAX(0,  start_h);
-    const int eh = MIN(IH, start_h + k1);
-    const int start_w = cur_ow * s0 - p0;
+    const int eh = MIN(args.IH, start_h + args.k1);
+    const int start_w = cur_ow * args.s0 - args.p0;
     const int bw = MAX(0,  start_w);
-    const int ew = MIN(IW, start_w + k0);
+    const int ew = MIN(args.IW, start_w + args.k0);
 
     float res = -INFINITY;
 
     for (int i = bh; i < eh; i += 1) {
         for (int j = bw; j < ew; j += 1) {
-            res = MAX(res, i_ptr[i * IW + j]);
+            res = MAX(res, i_ptr[i * args.IW + j]);
         }
     }
 
-    o_ptr[cur_oh * OW + cur_ow] = res;
+    o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
 
 kernel void kernel_pool_2d_avg_f32(
         device  const float * src0,
         device        float * dst,
-        constant    int32_t & k0,
-        constant    int32_t & k1,
-        constant    int32_t & s0,
-        constant    int32_t & s1,
-        constant    int32_t & p0,
-        constant    int32_t & p1,
-        constant    int64_t & IH,
-        constant    int64_t & IW,
-        constant    int64_t & OH,
-        constant    int64_t & OW,
-        constant    int64_t & parallel_elements,
+        constant    ggml_metal_kargs_pool_2d & args,
         uint        gid[[thread_position_in_grid]]) {
 
-    if (gid >= parallel_elements) {
+    if (gid >= args.parallel_elements) {
         return;
     }
 
     const int idx = gid;
-    const int I_HW = IH * IW;
-    const int O_HW = OH * OW;
+    const int I_HW = args.IH * args.IW;
+    const int O_HW = args.OH * args.OW;
     const int nc = idx / O_HW;
-    const int cur_oh = idx % O_HW / OW;
-    const int cur_ow = idx % O_HW % OW;
+    const int cur_oh = idx % O_HW / args.OW;
+    const int cur_ow = idx % O_HW % args.OW;
 
     device const float * i_ptr = src0 + nc * I_HW;
     device       float * o_ptr = dst  + nc * O_HW;
 
-    const int start_h = cur_oh * s1 - p1;
+    const int start_h = cur_oh * args.s1 - args.p1;
     const int bh = MAX(0,  start_h);
-    const int eh = MIN(IH, start_h + k1);
-    const int start_w = cur_ow * s0 - p0;
+    const int eh = MIN(args.IH, start_h + args.k1);
+    const int start_w = cur_ow * args.s0 - args.p0;
     const int bw = MAX(0,  start_w);
-    const int ew = MIN(IW, start_w + k0);
+    const int ew = MIN(args.IW, start_w + args.k0);
     // const float scale = 1. / ((eh - bh) * (ew - bw));
-    const float scale = 1. / (k0 * k1);
+    const float scale = 1. / (args.k0 * args.k1);
 
     float res = 0;
 
     for (int i = bh; i < eh; i += 1) {
         for (int j = bw; j < ew; j += 1) {
-            float cur = i_ptr[i * IW + j];
+            float cur = i_ptr[i * args.IW + j];
             res += cur * scale;
         }
     }
 
-    o_ptr[cur_oh * OW + cur_ow] = res;
+    o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
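The two pooling kernels that close this file now read their eleven parameters from one constant ggml_metal_kargs_pool_2d struct; the index math is otherwise untouched: each thread owns one output element, derives (nc, cur_oh, cur_ow) from its flat id, clamps the pooling window against the input bounds, and the averaging variant still divides by k0*k1 rather than by the clamped window size. The C++ sketch below replays that arithmetic on the CPU with a made-up 4x4 input so the mapping can be checked directly.

    // CPU re-run of the avg-pool index math from kernel_pool_2d_avg_f32, with the
    // same field names as ggml_metal_kargs_pool_2d. Input values are made up.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct pool_2d_args {
        int32_t k0, k1, s0, s1, p0, p1;
        int64_t IH, IW, OH, OW, parallel_elements;
    };

    int main() {
        // 1 channel, 4x4 input, 2x2 kernel, stride 2, no padding -> 2x2 output
        pool_2d_args args = {2, 2, 2, 2, 0, 0, 4, 4, 2, 2, 1*2*2};
        std::vector<float> src(4*4);
        for (int i = 0; i < 16; ++i) src[i] = (float)i;
        std::vector<float> dst(args.parallel_elements);

        for (int64_t gid = 0; gid < args.parallel_elements; ++gid) { // one thread per output element
            const int64_t O_HW   = args.OH*args.OW;
            const int64_t nc     = gid / O_HW;
            const int64_t cur_oh = gid % O_HW / args.OW;
            const int64_t cur_ow = gid % O_HW % args.OW;

            const int64_t bh = std::max<int64_t>(0,       cur_oh*args.s1 - args.p1);
            const int64_t eh = std::min<int64_t>(args.IH, cur_oh*args.s1 - args.p1 + args.k1);
            const int64_t bw = std::max<int64_t>(0,       cur_ow*args.s0 - args.p0);
            const int64_t ew = std::min<int64_t>(args.IW, cur_ow*args.s0 - args.p0 + args.k0);

            const float scale = 1.0f/(args.k0*args.k1); // full window, exactly as the kernel does
            float res = 0.0f;
            for (int64_t i = bh; i < eh; ++i) {
                for (int64_t j = bw; j < ew; ++j) {
                    res += src[nc*args.IH*args.IW + i*args.IW + j]*scale;
                }
            }
            dst[gid] = res;
        }
        for (float v : dst) printf("%g ", v); // expected: 2.5 4.5 10.5 12.5
        printf("\n");
        return 0;
    }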
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
index e3dc25f16..8721b272d 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -1,6 +1,70 @@
 #ifndef GGML_METAL_IMPL
 #define GGML_METAL_IMPL
 
+// kernel parameters for mat-vec threadgroups
+//
+// N_R0: number of src0 rows to process per simdgroup
+// N_SG: number of simdgroups per threadgroup
+//
+// TODO: for optimal performance, become function of the device and work size
+
+#define N_R0_Q4_0 4
+#define N_SG_Q4_0 2
+
+#define N_R0_Q4_1 4
+#define N_SG_Q4_1 2
+
+#define N_R0_Q5_0 4
+#define N_SG_Q5_0 2
+
+#define N_R0_Q5_1 4
+#define N_SG_Q5_1 2
+
+#define N_R0_Q8_0 4
+#define N_SG_Q8_0 2
+
+#define N_R0_Q2_K 4
+#define N_SG_Q2_K 2
+
+#define N_R0_Q3_K 2
+#define N_SG_Q3_K 2
+
+#define N_R0_Q4_K 4
+#define N_SG_Q4_K 2
+
+#define N_R0_Q5_K 2
+#define N_SG_Q5_K 2
+
+#define N_R0_Q6_K 1
+#define N_SG_Q6_K 2
+
+#define N_R0_IQ1_S 4
+#define N_SG_IQ1_S 2
+
+#define N_R0_IQ1_M 4
+#define N_SG_IQ1_M 2
+
+#define N_R0_IQ2_XXS 4
+#define N_SG_IQ2_XXS 2
+
+#define N_R0_IQ2_XS 4
+#define N_SG_IQ2_XS 2
+
+#define N_R0_IQ2_S 4
+#define N_SG_IQ2_S 2
+
+#define N_R0_IQ3_XXS 4
+#define N_SG_IQ3_XXS 2
+
+#define N_R0_IQ3_S 4
+#define N_SG_IQ3_S 2
+
+#define N_R0_IQ4_NL 2
+#define N_SG_IQ4_NL 2
+
+#define N_R0_IQ4_XS 2
+#define N_SG_IQ4_XS 2
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -155,9 +219,12 @@ typedef struct {
     int32_t  ne11;
     int32_t  ne_12_2; // assume K and V are same shape
     int32_t  ne_12_3;
-    uint64_t nb_12_1;
-    uint64_t nb_12_2;
-    uint64_t nb_12_3;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
     uint64_t nb31;
     int32_t  ne1;
     int32_t  ne2;
@@ -285,4 +352,246 @@ typedef struct {
     float    eps;
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_l2_norm;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  n_groups;
+    float    eps;
+} ggml_metal_kargs_group_norm;
+
+typedef struct {
+    int32_t  IC;
+    int32_t  IL;
+    int32_t  K;
+    int32_t  s0;
+    uint64_t nb0;
+    uint64_t nb1;
+} ggml_metal_kargs_conv_transpose_1d;
+
+typedef struct {
+    uint64_t  ofs0;
+    uint64_t  ofs1;
+    int32_t  IW;
+    int32_t  IH;
+    int32_t  CHW;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  p0;
+    int32_t  p1;
+    int32_t  d0;
+    int32_t  d1;
+    int32_t  N;
+    int32_t  KH;
+    int32_t  KW;
+    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
+} ggml_metal_kargs_im2col;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne10;
+    int64_t  ne11;
+    int64_t  ne12;
+    int64_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_sum_rows;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    uint32_t n_head_log2;
+} ggml_metal_kargs_soft_max;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int      n_past;
+} ggml_metal_kargs_diag_mask_inf;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t  ne10;
+    int64_t  ne11;
+    uint64_t nb10;
+    uint64_t nb11;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_ssm_conv;
+
+typedef struct {
+    int64_t  d_state;
+    int64_t  d_inner;
+    int64_t  n_seq_tokens;
+    int64_t  n_seqs;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb20;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb30;
+    uint64_t nb31;
+    uint64_t nb40;
+    uint64_t nb41;
+    uint64_t nb42;
+    uint64_t nb50;
+    uint64_t nb51;
+    uint64_t nb52;
+} ggml_metal_kargs_ssm_scan;
+
+typedef struct {
+    int64_t  ne00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t  ne10;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_get_rows;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    float    sf0;
+    float    sf1;
+    float    sf2;
+    float    sf3;
+} ggml_metal_kargs_upscale;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_pad;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  p0;
+    int32_t  p1;
+} ggml_metal_kargs_pad_reflect_1d;
+
+typedef struct {
+    uint64_t nb1;
+    int      dim;
+    int      max_period;
+} ggml_metal_kargs_timestep_embedding;
+
+typedef struct {
+    float    slope;
+} ggml_metal_kargs_leaky_relu;
+
+typedef struct {
+    int64_t  ncols;
+    int64_t  ncols_pad;
+} ggml_metal_kargs_argsort;
+
+typedef struct {
+    int64_t  ne0;
+    float    start;
+    float    step;
+} ggml_metal_kargs_arange;
+
+typedef struct {
+    int32_t  k0;
+    int32_t  k1;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  p0;
+    int32_t  p1;
+    int64_t  IH;
+    int64_t  IW;
+    int64_t  OH;
+    int64_t  OW;
+    int64_t  parallel_elements;
+} ggml_metal_kargs_pool_2d;
+
 #endif // GGML_METAL_IMPL
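These kargs structs are used by both the host code and the Metal kernels above (which take them as constant ... & args), which is why every field uses a fixed-width type: int32_t or int64_t for element counters (the comment near the top of the header notes counters typically stay int32_t to reduce register usage) and uint64_t for byte strides. The C++ sketch below mirrors ggml_metal_kargs_get_rows to illustrate the convention; the static_asserts are an added sanity check for this example, not something the original header contains.

    // Mirror of ggml_metal_kargs_get_rows: fixed-width members keep the struct
    // layout identical on the host and shader side. The asserts are illustrative.
    #include <cstddef>
    #include <cstdint>

    typedef struct {
        int64_t  ne00;
        uint64_t nb01;
        uint64_t nb02;
        int64_t  ne10;
        uint64_t nb10;
        uint64_t nb11;
        uint64_t nb1;
        uint64_t nb2;
    } kargs_get_rows;

    static_assert(sizeof(kargs_get_rows) == 8*sizeof(uint64_t), "unexpected padding");
    static_assert(offsetof(kargs_get_rows, nb2) == 7*sizeof(uint64_t), "unexpected field offset");

    int main() { return 0; }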
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
index d8422f1b7..fea505214 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
@@ -46,6 +46,7 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
 static struct ggml_backend_metal_device_context {
     id<MTLDevice> mtl_device;
     int           mtl_device_ref_count;
+    id<MTLLibrary> mtl_library;
 
     bool has_simdgroup_reduction;
     bool has_simdgroup_mm;
@@ -57,6 +58,7 @@ static struct ggml_backend_metal_device_context {
 } g_ggml_ctx_dev_main = {
     /*.mtl_device              =*/ nil,
     /*.mtl_device_ref_count    =*/ 0,
+    /*.mtl_library             =*/ nil,
     /*.has_simdgroup_reduction =*/ false,
     /*.has_simdgroup_mm        =*/ false,
     /*.has_residency_sets      =*/ false,
@@ -108,6 +110,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
+        if (ctx->mtl_library) {
+            [ctx->mtl_library release];
+            ctx->mtl_library = nil;
+        }
+
         if (ctx->mtl_device) {
             [ctx->mtl_device release];
             ctx->mtl_device = nil;
@@ -177,10 +184,13 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -342,42 +352,56 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
@@ -386,6 +410,20 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
@@ -469,11 +507,13 @@ struct ggml_backend_metal_context {
 //       for now it is easier to work in a separate file
 // static NSString * const msl_library_source = @"see metal.metal";
 
+#if !GGML_METAL_EMBED_LIBRARY
 // Here to assist with NSBundle Path Hack
 @interface GGMLMetalClass : NSObject
 @end
 @implementation GGMLMetalClass
 @end
+#endif
 
 static void * ggml_metal_host_malloc(size_t n) {
     void * data = NULL;
@@ -495,6 +535,139 @@ static void * ggml_metal_host_malloc(size_t n) {
     return data;
 }
 
+// load library
+//
+// - first check if the library is embedded
+// - then check if the library is in the bundle
+// - if not found, load the source and compile it
+// - if that fails, return NULL
+static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfloat) {
+    id<MTLLibrary> metal_library = nil;
+    NSError * error = nil;
+    NSString * src = nil;
+
+#if GGML_METAL_EMBED_LIBRARY
+    GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
+
+    extern const char ggml_metallib_start[];
+    extern const char ggml_metallib_end[];
+
+    src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
+
+#else
+
+#ifdef SWIFT_PACKAGE
+    NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
+#else
+    NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+#endif
+
+    NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
+    if (path_lib == nil) {
+        // Try to find the resource in the directory where the current binary located.
+        NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
+        NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
+        NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
+        if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
+            GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
+            NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
+            if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
+                // Optionally, if this is a symlink, try to resolve it.
+                default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
+                if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
+                    // It is a relative path, adding the binary directory as directory prefix.
+                    default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
+                }
+                if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
+                    // Link to the resource could not be resolved.
+                    default_metallib_path = nil;
+                } else {
+                    GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
+                }
+            }
+        } else {
+            // The resource couldn't be found in the binary's directory.
+            default_metallib_path = nil;
+        }
+        path_lib = default_metallib_path;
+    }
+
+    if (path_lib != nil) {
+        // pre-compiled library found
+        NSURL * libURL = [NSURL fileURLWithPath:path_lib];
+        GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
+
+        metal_library = [device newLibraryWithURL:libURL error:&error];
+        if (error) {
+            GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+            return NULL;
+        }
+    } else {
+        GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
+
+        NSString * path_source;
+        NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+        GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
+
+        if (path_resource) {
+            path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
+        } else {
+            path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+        }
+
+        if (path_source == nil) {
+            GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
+            path_source = @"ggml-metal.metal";
+        }
+
+        GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
+
+        src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
+        if (error) {
+            GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+            return NULL;
+        }
+    }
+#endif
+
+    if (!metal_library) {
+        @autoreleasepool {
+            // dictionary of preprocessor macros
+            NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
+            if (use_bfloat) {
+                [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
+            }
+
+#if GGML_METAL_EMBED_LIBRARY
+            [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
+#endif
+
+            MTLCompileOptions * options = [MTLCompileOptions new];
+            options.preprocessorMacros = prep;
+
+            //[options setFastMathEnabled:false];
+
+            metal_library = [device newLibraryWithSource:src options:options error:&error];
+            if (error) {
+                GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return NULL;
+            }
+
+#if !__has_feature(objc_arc)
+            [options release];
+#endif
+        }
+    }
+
+#if GGML_METAL_EMBED_LIBRARY
+    [src release];
+#endif // GGML_METAL_EMBED_LIBRARY
+
+    return metal_library;
+}
+
 static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
     GGML_LOG_INFO("%s: allocating\n", __func__);
 
@@ -522,137 +695,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
-    id<MTLLibrary> metal_library;
-
     // load library
-    //
-    // - first check if the library is embedded
-    // - then check if the library is in the bundle
-    // - if not found, load the source and compile it
-    // - if that fails, return NULL
-    {
-        NSBundle * bundle = nil;
-#ifdef SWIFT_PACKAGE
-        bundle = SWIFTPM_MODULE_BUNDLE;
-#else
-        bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-#endif
-
-        NSError * error = nil;
-
-#if GGML_METAL_EMBED_LIBRARY
-        const bool try_metallib = false;
-#else
-        const bool try_metallib = true;
-#endif
-
-        NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
-        if (path_lib == nil) {
-            // Try to find the resource in the directory where the current binary located.
-            NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
-            NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
-            NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
-            if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
-                GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
-                NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
-                if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
-                    // Optionally, if this is a symlink, try to resolve it.
-                    default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
-                    if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
-                        // It is a relative path, adding the binary directory as directory prefix.
-                        default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
-                    }
-                    if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
-                        // Link to the resource could not be resolved.
-                        default_metallib_path = nil;
-                    } else {
-                        GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
-                    }
-                }
-            } else {
-                // The resource couldn't be found in the binary's directory.
-                default_metallib_path = nil;
-            }
-            path_lib = default_metallib_path;
-        }
-
-        if (try_metallib && path_lib != nil) {
-            // pre-compiled library found
-            NSURL * libURL = [NSURL fileURLWithPath:path_lib];
-            GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
-
-            metal_library = [device newLibraryWithURL:libURL error:&error];
-            if (error) {
-                GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-                return NULL;
-            }
-        } else {
-#if GGML_METAL_EMBED_LIBRARY
-            GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
-
-            extern const char ggml_metallib_start[];
-            extern const char ggml_metallib_end[];
-
-            NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
-#else
-            GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
-
-            NSString * path_source;
-            NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
-
-            GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
-
-            if (path_resource) {
-                path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
-            } else {
-                path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-            }
-
-            if (path_source == nil) {
-                GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
-                path_source = @"ggml-metal.metal";
-            }
-
-            GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
-
-            NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
-            if (error) {
-                GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-                return NULL;
-            }
-#endif // GGML_METAL_EMBED_LIBRARY
-
-            @autoreleasepool {
-                // dictionary of preprocessor macros
-                NSMutableDictionary * prep = [NSMutableDictionary dictionary];
-
-                if (ctx_dev->use_bfloat) {
-                    [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
-                }
-
-#if GGML_METAL_EMBED_LIBRARY
-                [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
-#endif
-
-                MTLCompileOptions * options = [MTLCompileOptions new];
-                options.preprocessorMacros = prep;
-
-                //[options setFastMathEnabled:false];
-
-                metal_library = [device newLibraryWithSource:src options:options error:&error];
-                if (error) {
-                    GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-                    return NULL;
-                }
-
-#if !__has_feature(objc_arc)
-                [options release];
-#endif
-            }
-#if GGML_METAL_EMBED_LIBRARY
-            [src release];
-#endif // GGML_METAL_EMBED_LIBRARY
-        }
+    if (ctx_dev->mtl_library == nil) {
+        ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
+    }
+    id<MTLLibrary> metal_library = ctx_dev->mtl_library;
+    if (metal_library == nil) {
+        GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
+        return NULL;
     }
 
     // print MTL GPU family:
@@ -726,7 +776,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
             [metal_function release]; \
             if (error) { \
                 GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
-                [metal_library release]; \
                 return NULL; \
             } \
         } else { \
@@ -739,316 +788,345 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                       add_row,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                           sub,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,                       sub_row,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                           mul,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                       div_row,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                    repeat_f32,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                    repeat_f16,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,                    repeat_i32,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16,                    repeat_i16,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                         scale,                          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                       scale_4,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                         clamp,                          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                          tanh,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                          relu,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                       sigmoid,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                          gelu,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                        gelu_4,                         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                    gelu_quick,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                  gelu_quick_4,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                          silu,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                        silu_4,                         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU,                           elu,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,                 diag_mask_inf,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,               diag_mask_inf_8,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                  get_rows_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,                  get_rows_f16,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,                 get_rows_bf16,                  use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,                 get_rows_q4_0,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,                 get_rows_q4_1,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                 get_rows_q5_0,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,                 get_rows_q5_1,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,                 get_rows_q8_0,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,                 get_rows_q2_K,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,                 get_rows_q3_K,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,                 get_rows_q4_K,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,                 get_rows_q5_K,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,                 get_rows_q6_K,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,              get_rows_iq2_xxs,               true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,               get_rows_iq2_xs,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,              get_rows_iq3_xxs,               true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,                get_rows_iq3_s,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,                get_rows_iq2_s,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,                get_rows_iq1_s,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,                get_rows_iq1_m,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,               get_rows_iq4_nl,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,               get_rows_iq4_xs,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                  get_rows_i32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                  ssm_conv_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                  ssm_scan_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,               mul_mv_bf16_f32,                has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,          mul_mv_bf16_f32_1row,           has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,            mul_mv_bf16_f32_l4,             has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,              mul_mv_bf16_bf16,               has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,       mul_mv_ext_f16_f32_r1_2,        has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,       mul_mv_ext_f16_f32_r1_3,        has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,       mul_mv_ext_f16_f32_r1_4,        has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,       mul_mv_ext_f16_f32_r1_5,        has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,      mul_mv_ext_q4_0_f32_r1_2,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,      mul_mv_ext_q4_0_f32_r1_3,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,      mul_mv_ext_q4_0_f32_r1_4,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,      mul_mv_ext_q4_0_f32_r1_5,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,      mul_mv_ext_q4_1_f32_r1_2,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,      mul_mv_ext_q4_1_f32_r1_3,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,      mul_mv_ext_q4_1_f32_r1_4,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,      mul_mv_ext_q4_1_f32_r1_5,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,      mul_mv_ext_q5_0_f32_r1_2,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,      mul_mv_ext_q5_0_f32_r1_3,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,      mul_mv_ext_q5_0_f32_r1_4,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,      mul_mv_ext_q5_0_f32_r1_5,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,      mul_mv_ext_q5_1_f32_r1_2,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,      mul_mv_ext_q5_1_f32_r1_3,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,      mul_mv_ext_q5_1_f32_r1_4,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,      mul_mv_ext_q5_1_f32_r1_5,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,      mul_mv_ext_q8_0_f32_r1_2,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,      mul_mv_ext_q8_0_f32_r1_3,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,      mul_mv_ext_q8_0_f32_r1_4,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,      mul_mv_ext_q8_0_f32_r1_5,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,      mul_mv_ext_q4_K_f32_r1_2,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,      mul_mv_ext_q4_K_f32_r1_3,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,      mul_mv_ext_q4_K_f32_r1_4,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,      mul_mv_ext_q4_K_f32_r1_5,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,      mul_mv_ext_q5_K_f32_r1_2,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,      mul_mv_ext_q5_K_f32_r1_3,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,      mul_mv_ext_q5_K_f32_r1_4,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,      mul_mv_ext_q5_K_f32_r1_5,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,      mul_mv_ext_q6_K_f32_r1_2,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,      mul_mv_ext_q6_K_f32_r1_3,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,      mul_mv_ext_q6_K_f32_r1_4,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,      mul_mv_ext_q6_K_f32_r1_5,       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,    mul_mv_ext_iq4_nl_f32_r1_2,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,    mul_mv_ext_iq4_nl_f32_r1_3,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,    mul_mv_ext_iq4_nl_f32_r1_4,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,    mul_mv_ext_iq4_nl_f32_r1_5,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              has_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         has_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           has_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,            mul_mv_id_bf16_f32,             has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,               mul_mm_bf16_f32,                has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,            mul_mm_id_bf16_f32,             has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                 rope_norm_f32,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                im2col_ext_f16,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,                im2col_ext_f32,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,     conv_transpose_1d_f32_f32,      true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,     conv_transpose_1d_f16_f32,      true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,            pad_reflect_1d_f32,             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32,                     unpad_f32,                      true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,          argsort_f32_i32_desc,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                leaky_relu_f32,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,       flash_attn_ext_bf16_h64,        has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,       flash_attn_ext_bf16_h80,        has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,       flash_attn_ext_bf16_h96,        has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,      flash_attn_ext_bf16_h112,       has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,      flash_attn_ext_bf16_h128,       has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,      flash_attn_ext_bf16_h256,       has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,       flash_attn_ext_q4_0_h64,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,       flash_attn_ext_q4_0_h80,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,       flash_attn_ext_q4_0_h96,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,      flash_attn_ext_q4_0_h112,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,      flash_attn_ext_q4_0_h128,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,      flash_attn_ext_q4_0_h256,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,       flash_attn_ext_q4_1_h64,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,       flash_attn_ext_q4_1_h80,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,       flash_attn_ext_q4_1_h96,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,      flash_attn_ext_q4_1_h112,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,      flash_attn_ext_q4_1_h128,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,      flash_attn_ext_q4_1_h256,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,       flash_attn_ext_q5_0_h64,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,       flash_attn_ext_q5_0_h80,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,       flash_attn_ext_q5_0_h96,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,      flash_attn_ext_q5_0_h112,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,      flash_attn_ext_q5_0_h128,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,      flash_attn_ext_q5_0_h256,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,       flash_attn_ext_q5_1_h64,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,       flash_attn_ext_q5_1_h80,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,       flash_attn_ext_q5_1_h96,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,      flash_attn_ext_q5_1_h112,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,      flash_attn_ext_q5_1_h128,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,      flash_attn_ext_q5_1_h256,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,       flash_attn_ext_q8_0_h64,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,       flash_attn_ext_q8_0_h80,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,       flash_attn_ext_q8_0_h96,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,      flash_attn_ext_q8_0_h112,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,      flash_attn_ext_q8_0_h128,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,      flash_attn_ext_q8_0_h256,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,  flash_attn_ext_vec_bf16_h128,   has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,  flash_attn_ext_vec_q4_0_h128,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,  flash_attn_ext_vec_q4_1_h128,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,  flash_attn_ext_vec_q5_0_h128,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,  flash_attn_ext_vec_q5_1_h128,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,  flash_attn_ext_vec_q8_0_h128,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,  flash_attn_ext_vec_bf16_h256,   has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,  flash_attn_ext_vec_q4_0_h256,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,  flash_attn_ext_vec_q4_1_h256,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,  flash_attn_ext_vec_q5_1_h256,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,  flash_attn_ext_vec_q8_0_h256,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32,                       set_f32,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32,                       set_i32,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,                  cpy_f32_bf16,                   use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,                  cpy_bf16_f32,                   use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,                 cpy_bf16_bf16,                  use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                  cpy_f32_q4_0,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                  cpy_f32_q4_1,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,                  cpy_q4_0_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,                  cpy_q4_0_f16,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,                  cpy_q4_1_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,                  cpy_q4_1_f16,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,                  cpy_q5_0_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,                  cpy_q5_0_f16,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,                  cpy_q5_1_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,                  cpy_q5_1_f16,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,                  cpy_q8_0_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,                  cpy_q8_0_f16,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                           neg,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                        argmax,                         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,               pool_2d_avg_f32,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,               pool_2d_max_f32,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                             add,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                         add_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                             sub,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,                         sub_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                             mul,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                         mul_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                             div,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                         div_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                      repeat_f32,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                      repeat_f16,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,                      repeat_i32,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16,                      repeat_i16,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                           scale,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                         scale_4,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                           clamp,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                            tanh,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                            relu,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                         sigmoid,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                            gelu,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                          gelu_4,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                      gelu_quick,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                    gelu_quick_4,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                            silu,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                          silu_4,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU,                             elu,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                    soft_max_f16,                    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                  soft_max_f16_4,                  has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                    soft_max_f32,                    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                  soft_max_f32_4,                  has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,                   diag_mask_inf,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,                 diag_mask_inf_8,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                    get_rows_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,                    get_rows_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,                   get_rows_bf16,                   use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,                   get_rows_q4_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,                   get_rows_q4_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                   get_rows_q5_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,                   get_rows_q5_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,                   get_rows_q8_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,                   get_rows_q2_K,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,                   get_rows_q3_K,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,                   get_rows_q4_K,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,                   get_rows_q5_K,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,                   get_rows_q6_K,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,                get_rows_iq2_xxs,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,                 get_rows_iq2_xs,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,                get_rows_iq3_xxs,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,                  get_rows_iq3_s,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,                  get_rows_iq2_s,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,                  get_rows_iq1_s,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,                  get_rows_iq1_m,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,                 get_rows_iq4_nl,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,                 get_rows_iq4_xs,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                    get_rows_i32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                        rms_norm,                        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                    ssm_conv_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                    ssm_scan_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,                   rwkv_wkv6_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,                   rwkv_wkv7_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                  mul_mv_f32_f32,                  has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,                 mul_mv_bf16_f32,                 has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,            mul_mv_bf16_f32_1row,            has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,              mul_mv_bf16_f32_l4,              has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,                mul_mv_bf16_bf16,                has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                  mul_mv_f16_f32,                  has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,             mul_mv_f16_f32_1row,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,               mul_mv_f16_f32_l4,               has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                  mul_mv_f16_f16,                  has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,                 mul_mv_q4_0_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,                 mul_mv_q4_1_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,                 mul_mv_q5_0_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,                 mul_mv_q5_1_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,                 mul_mv_q8_0_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,         mul_mv_ext_f16_f32_r1_2,         has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,         mul_mv_ext_f16_f32_r1_3,         has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,         mul_mv_ext_f16_f32_r1_4,         has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,         mul_mv_ext_f16_f32_r1_5,         has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,        mul_mv_ext_q4_0_f32_r1_2,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,        mul_mv_ext_q4_0_f32_r1_3,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,        mul_mv_ext_q4_0_f32_r1_4,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,        mul_mv_ext_q4_0_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,        mul_mv_ext_q4_1_f32_r1_2,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,        mul_mv_ext_q4_1_f32_r1_3,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,        mul_mv_ext_q4_1_f32_r1_4,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,        mul_mv_ext_q4_1_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,        mul_mv_ext_q5_0_f32_r1_2,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,        mul_mv_ext_q5_0_f32_r1_3,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,        mul_mv_ext_q5_0_f32_r1_4,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,        mul_mv_ext_q5_0_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,        mul_mv_ext_q5_1_f32_r1_2,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,        mul_mv_ext_q5_1_f32_r1_3,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,        mul_mv_ext_q5_1_f32_r1_4,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,        mul_mv_ext_q5_1_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,        mul_mv_ext_q8_0_f32_r1_2,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,        mul_mv_ext_q8_0_f32_r1_3,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,        mul_mv_ext_q8_0_f32_r1_4,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,        mul_mv_ext_q8_0_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,        mul_mv_ext_q4_K_f32_r1_2,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,        mul_mv_ext_q4_K_f32_r1_3,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,        mul_mv_ext_q4_K_f32_r1_4,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,        mul_mv_ext_q4_K_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,        mul_mv_ext_q5_K_f32_r1_2,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,        mul_mv_ext_q5_K_f32_r1_3,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,        mul_mv_ext_q5_K_f32_r1_4,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,        mul_mv_ext_q5_K_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,        mul_mv_ext_q6_K_f32_r1_2,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,        mul_mv_ext_q6_K_f32_r1_3,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,        mul_mv_ext_q6_K_f32_r1_4,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,        mul_mv_ext_q6_K_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,      mul_mv_ext_iq4_nl_f32_r1_2,      has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,      mul_mv_ext_iq4_nl_f32_r1_3,      has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,      mul_mv_ext_iq4_nl_f32_r1_4,      has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,      mul_mv_ext_iq4_nl_f32_r1_5,      has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,                 mul_mv_q2_K_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,                 mul_mv_q3_K_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,                 mul_mv_q4_K_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,                 mul_mv_q5_K_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,                 mul_mv_q6_K_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,              mul_mv_iq2_xxs_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,               mul_mv_iq2_xs_f32,               has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,              mul_mv_iq3_xxs_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,                mul_mv_iq3_s_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,                mul_mv_iq2_s_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,                mul_mv_iq1_s_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,                mul_mv_iq1_m_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,               mul_mv_iq4_nl_f32,               has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,               mul_mv_iq4_xs_f32,               has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,               mul_mv_id_f32_f32,               has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,               mul_mv_id_f16_f32,               has_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,          mul_mv_id_f16_f32_1row,          has_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,            mul_mv_id_f16_f32_l4,            has_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,               mul_mv_id_f16_f16,               has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,              mul_mv_id_bf16_f32,              has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,              mul_mv_id_q4_0_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,              mul_mv_id_q4_1_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,              mul_mv_id_q5_0_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,              mul_mv_id_q5_1_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,              mul_mv_id_q8_0_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,              mul_mv_id_q2_K_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,              mul_mv_id_q3_K_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,              mul_mv_id_q4_K_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,              mul_mv_id_q5_K_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,              mul_mv_id_q6_K_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,           mul_mv_id_iq2_xxs_f32,           has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,            mul_mv_id_iq2_xs_f32,            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,           mul_mv_id_iq3_xxs_f32,           has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,             mul_mv_id_iq3_s_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,             mul_mv_id_iq2_s_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,             mul_mv_id_iq1_s_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,             mul_mv_id_iq1_m_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,            mul_mv_id_iq4_nl_f32,            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,            mul_mv_id_iq4_xs_f32,            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                  mul_mm_f32_f32,                  has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                  mul_mm_f16_f32,                  has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,                 mul_mm_bf16_f32,                 has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,                 mul_mm_q4_0_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,                 mul_mm_q4_1_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,                 mul_mm_q5_0_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,                 mul_mm_q5_1_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,                 mul_mm_q8_0_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,                 mul_mm_q2_K_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,                 mul_mm_q3_K_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,                 mul_mm_q4_K_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,                 mul_mm_q5_K_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,                 mul_mm_q6_K_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,              mul_mm_iq2_xxs_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,               mul_mm_iq2_xs_f32,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,              mul_mm_iq3_xxs_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,                mul_mm_iq3_s_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,                mul_mm_iq2_s_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,                mul_mm_iq1_s_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,                mul_mm_iq1_m_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,               mul_mm_iq4_nl_f32,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,               mul_mm_iq4_xs_f32,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,               mul_mm_id_f32_f32,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,               mul_mm_id_f16_f32,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,              mul_mm_id_bf16_f32,              has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,              mul_mm_id_q4_0_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,              mul_mm_id_q4_1_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,              mul_mm_id_q5_0_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,              mul_mm_id_q5_1_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,              mul_mm_id_q8_0_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,              mul_mm_id_q2_K_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,              mul_mm_id_q3_K_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,              mul_mm_id_q4_K_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,              mul_mm_id_q5_K_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,              mul_mm_id_q6_K_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,           mul_mm_id_iq2_xxs_f32,           has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,            mul_mm_id_iq2_xs_f32,            has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,           mul_mm_id_iq3_xxs_f32,           has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,             mul_mm_id_iq3_s_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,             mul_mm_id_iq2_s_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,             mul_mm_id_iq1_s_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,             mul_mm_id_iq1_m_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,            mul_mm_id_iq4_nl_f32,            has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,            mul_mm_id_iq4_xs_f32,            has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                   rope_norm_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                   rope_norm_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                   rope_neox_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                   rope_neox_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                      im2col_f16,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                      im2col_f32,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                  im2col_ext_f16,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,                  im2col_ext_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,       conv_transpose_1d_f32_f32,       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,       conv_transpose_1d_f16_f32,       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                     upscale_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                         pad_f32,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,              pad_reflect_1d_f32,              true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32,                       unpad_f32,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,          timestep_embedding_f32,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                      arange_f32,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,             argsort_f32_i32_asc,             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,            argsort_f32_i32_desc,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                  leaky_relu_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,          flash_attn_ext_f16_h64,          has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,          flash_attn_ext_f16_h80,          has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,          flash_attn_ext_f16_h96,          has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,         flash_attn_ext_f16_h112,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,         flash_attn_ext_f16_h128,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,         flash_attn_ext_f16_h192,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,  flash_attn_ext_f16_hk192_hv128,  has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,         flash_attn_ext_f16_h256,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,         flash_attn_ext_bf16_h64,         has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,         flash_attn_ext_bf16_h80,         has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,         flash_attn_ext_bf16_h96,         has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,        flash_attn_ext_bf16_h112,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,        flash_attn_ext_bf16_h128,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,        flash_attn_ext_bf16_h192,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,        flash_attn_ext_bf16_h256,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,         flash_attn_ext_q4_0_h64,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,         flash_attn_ext_q4_0_h80,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,         flash_attn_ext_q4_0_h96,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,        flash_attn_ext_q4_0_h112,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,        flash_attn_ext_q4_0_h128,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,        flash_attn_ext_q4_0_h192,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,        flash_attn_ext_q4_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,         flash_attn_ext_q4_1_h64,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,         flash_attn_ext_q4_1_h80,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,         flash_attn_ext_q4_1_h96,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,        flash_attn_ext_q4_1_h112,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,        flash_attn_ext_q4_1_h128,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,        flash_attn_ext_q4_1_h192,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,        flash_attn_ext_q4_1_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,         flash_attn_ext_q5_0_h64,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,         flash_attn_ext_q5_0_h80,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,         flash_attn_ext_q5_0_h96,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,        flash_attn_ext_q5_0_h112,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,        flash_attn_ext_q5_0_h128,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,        flash_attn_ext_q5_0_h192,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,        flash_attn_ext_q5_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,         flash_attn_ext_q5_1_h64,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,         flash_attn_ext_q5_1_h80,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,         flash_attn_ext_q5_1_h96,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,        flash_attn_ext_q5_1_h112,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,        flash_attn_ext_q5_1_h128,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,        flash_attn_ext_q5_1_h192,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,        flash_attn_ext_q5_1_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,         flash_attn_ext_q8_0_h64,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,         flash_attn_ext_q8_0_h80,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,         flash_attn_ext_q8_0_h96,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,        flash_attn_ext_q8_0_h112,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,        flash_attn_ext_q8_0_h128,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,        flash_attn_ext_q8_0_h192,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,        flash_attn_ext_q8_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,     flash_attn_ext_vec_f16_h128,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,    flash_attn_ext_vec_bf16_h128,    has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,    flash_attn_ext_vec_q4_0_h128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,    flash_attn_ext_vec_q4_1_h128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,    flash_attn_ext_vec_q5_0_h128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,    flash_attn_ext_vec_q5_1_h128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,    flash_attn_ext_vec_q8_0_h128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,     flash_attn_ext_vec_f16_h192,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,    flash_attn_ext_vec_bf16_h192,    has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,    flash_attn_ext_vec_q4_0_h192,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,    flash_attn_ext_vec_q4_1_h192,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,    flash_attn_ext_vec_q5_0_h192,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,    flash_attn_ext_vec_q5_1_h192,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,    flash_attn_ext_vec_q8_0_h192,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,     flash_attn_ext_vec_f16_hk192_hv128,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,    flash_attn_ext_vec_bf16_hk192_hv128,    has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,    flash_attn_ext_vec_q4_0_hk192_hv128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,    flash_attn_ext_vec_q4_1_hk192_hv128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,    flash_attn_ext_vec_q5_0_hk192_hv128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,    flash_attn_ext_vec_q5_1_hk192_hv128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,    flash_attn_ext_vec_q8_0_hk192_hv128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,     flash_attn_ext_vec_f16_h256,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,    flash_attn_ext_vec_bf16_h256,    has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,    flash_attn_ext_vec_q4_0_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,    flash_attn_ext_vec_q4_1_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,    flash_attn_ext_vec_q5_0_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,    flash_attn_ext_vec_q5_1_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,    flash_attn_ext_vec_q8_0_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32,                         set_f32,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32,                         set_i32,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                     cpy_f32_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                     cpy_f32_f16,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,                    cpy_f32_bf16,                    use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                     cpy_f16_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                     cpy_f16_f16,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,                    cpy_bf16_f32,                    use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,                   cpy_bf16_bf16,                   use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                    cpy_f32_q8_0,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                    cpy_f32_q4_0,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                    cpy_f32_q4_1,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                    cpy_f32_q5_0,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                    cpy_f32_q5_1,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                  cpy_f32_iq4_nl,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,                    cpy_q4_0_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,                    cpy_q4_0_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,                    cpy_q4_1_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,                    cpy_q4_1_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,                    cpy_q5_0_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,                    cpy_q5_0_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,                    cpy_q5_1_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,                    cpy_q5_1_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,                    cpy_q8_0_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,                    cpy_q8_0_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                          concat,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                             sqr,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                            sqrt,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                             sin,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                             cos,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                             neg,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,                 pool_2d_max_f32,                 true);
     }
 
-    [metal_library release];
-
     return ctx;
 }
 
@@ -1205,7 +1283,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_NEG:
-                    return ggml_is_contiguous(op->src[0]);
+                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
             }
@@ -1215,26 +1293,32 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_TRANSPOSE:
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
+            return true;
         case GGML_OP_ADD:
         case GGML_OP_SUB:
-        case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ACC:
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
-        case GGML_OP_CLAMP:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
+        case GGML_OP_CLAMP:
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
-            return ggml_is_contiguous(op->src[0]);
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_LOG:
+            return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
             return true;
@@ -1255,23 +1339,32 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
             return false;
-        case GGML_OP_POOL_2D:
         case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
+        case GGML_OP_POOL_2D:
         case GGML_OP_PAD:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_UNPAD:
-        case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 32) {
+                // head size == 32 (e.g. bert-bge-small)
+                // TODO: not sure if it is worth adding kernels for this size
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
             return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -1956,34 +2049,38 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+
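+                // pack the src0/src1/dst shapes and strides into a single kargs struct, bound once at buffer index 2 below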
+                ggml_metal_kargs_sum_rows args = {
+                   /*.ne00 =*/ ne00,
+                   /*.ne01 =*/ ne01,
+                   /*.ne02 =*/ ne02,
+                   /*.ne03 =*/ ne03,
+                   /*.nb00 =*/ nb00,
+                   /*.nb01 =*/ nb01,
+                   /*.nb02 =*/ nb02,
+                   /*.nb03 =*/ nb03,
+                   /*.ne10 =*/ ne10,
+                   /*.ne11 =*/ ne11,
+                   /*.ne12 =*/ ne12,
+                   /*.ne13 =*/ ne13,
+                   /*.nb10 =*/ nb10,
+                   /*.nb11 =*/ nb11,
+                   /*.nb12 =*/ nb12,
+                   /*.nb13 =*/ nb13,
+                   /*.ne0  =*/ ne0,
+                   /*.ne1  =*/ ne1,
+                   /*.ne2  =*/ ne2,
+                   /*.ne3  =*/ ne3,
+                   /*.nb0  =*/ nb0,
+                   /*.nb1  =*/ nb1,
+                   /*.nb2  =*/ nb2,
+                   /*.nb3  =*/ nb3,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -2032,8 +2129,17 @@ static void ggml_metal_encode_node(
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-                // TODO: add ggml_metal_kargs struct
-                // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
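+                // sizes, scale, max_bias and the precomputed ALiBi slope bases m0/m1 now travel in one kargs struct (buffer index 3)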
+                ggml_metal_kargs_soft_max args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.scale =*/ scale,
+                    /*.max_bias =*/ max_bias,
+                    /*.m0 =*/ m0,
+                    /*.m1 =*/ m1,
+                    /*.n_head_log2 =*/ n_head_log2,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 if (id_src1) {
@@ -2042,14 +2148,7 @@ static void ggml_metal_encode_node(
                     [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
                 }
                 [encoder setBuffer:id_dst      offset:offs_dst            atIndex:2];
-                [encoder setBytes:&ne00        length:sizeof(ne00)        atIndex:3];
-                [encoder setBytes:&ne01        length:sizeof(ne01)        atIndex:4];
-                [encoder setBytes:&ne02        length:sizeof(ne02)        atIndex:5];
-                [encoder setBytes:&scale       length:sizeof(scale)       atIndex:6];
-                [encoder setBytes:&max_bias    length:sizeof(max_bias)    atIndex:7];
-                [encoder setBytes:&m0          length:sizeof(m0)          atIndex:8];
-                [encoder setBytes:&m1          length:sizeof(m1)          atIndex:9];
-                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
+                [encoder setBytes:&args        length:sizeof(args)        atIndex:3];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
@@ -2067,13 +2166,16 @@ static void ggml_metal_encode_node(
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
                 }
 
-                // TODO: add ggml_metal_kargs struct
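+                // the diagonal -INF mask only needs ne00, ne01 and n_past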
+                ggml_metal_kargs_diag_mask_inf args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.n_past =*/ n_past,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
+                [encoder setBytes:&args  length:sizeof(args) atIndex:2];
 
                 if (ne00%8 == 0) {
                     [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@@ -2092,27 +2194,30 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
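+                // SSM_CONV: shapes/strides of src0, src1 and dst, bound as one struct at buffer index 3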
+                ggml_metal_kargs_ssm_conv args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01    length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&nb00    length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&ne10    length:sizeof(ne10) atIndex:9];
-                [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:10];
-                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
-                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
-                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                [encoder setBytes:&ne2     length:sizeof(ne2)  atIndex:15];
-                [encoder setBytes:&nb0     length:sizeof(nb0)  atIndex:16];
-                [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:17];
-                [encoder setBytes:&nb2     length:sizeof(nb2)  atIndex:18];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -2163,7 +2268,64 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_ssm_scan args = {
+                    /*.d_state =*/ d_state,
+                    /*.d_inner =*/ d_inner,
+                    /*.n_seq_tokens =*/ n_seq_tokens,
+                    /*.n_seqs =*/ n_seqs,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.nb20 =*/ nb20,
+                    /*.nb21 =*/ nb21,
+                    /*.nb22 =*/ nb22,
+                    /*.nb30 =*/ nb30,
+                    /*.nb31 =*/ nb31,
+                    /*.nb40 =*/ nb40,
+                    /*.nb41 =*/ nb41,
+                    /*.nb42 =*/ nb42,
+                    /*.nb50 =*/ nb50,
+                    /*.nb51 =*/ nb51,
+                    /*.nb52 =*/ nb52,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:7];
+
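+                // launch one single-thread threadgroup per (d_inner, n_seqs) pair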
+                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_RWKV_WKV6:
+            {
+                const int64_t B = dst->src[5]->ne[1];
+                const int64_t T = dst->src[0]->ne[2];
+                const int64_t C = dst->ne[0];
+                const int64_t H = dst->src[0]->ne[1];
+
+                GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+                GGML_ASSERT(C % H == 0);
+                GGML_ASSERT(C / H == 64);
+
+                size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
+                size_t offs_src5 = 0;
+
+                id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+                id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2173,31 +2335,52 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
 
-                [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
-                [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
-                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
-                [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
+                [encoder setBytes:&B length:sizeof(B) atIndex:7];
+                [encoder setBytes:&T length:sizeof(T) atIndex:8];
+                [encoder setBytes:&C length:sizeof(C) atIndex:9];
+                [encoder setBytes:&H length:sizeof(H) atIndex:10];
 
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
-                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
-                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
-                [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
-                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
-                [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
-                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
-                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
-                [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
-                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
-                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
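+                // one threadgroup per (batch, head) pair; C/H threads per group (head size, asserted to be 64)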
+                [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
+            } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                const int64_t B = dst->src[6]->ne[1];
+                const int64_t T = dst->src[0]->ne[2];
+                const int64_t C = dst->ne[0];
+                const int64_t H = dst->src[0]->ne[1];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+                GGML_ASSERT(C % H == 0);
+                GGML_ASSERT(C / H == 64);
+
+                size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
+                size_t offs_src5 = 0;
+                size_t offs_src6 = 0;
+
+                id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+                id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+                id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:7];
+
+                [encoder setBytes:&B length:sizeof(B) atIndex:8];
+                [encoder setBytes:&T length:sizeof(T) atIndex:9];
+                [encoder setBytes:&C length:sizeof(C) atIndex:10];
+                [encoder setBytes:&H length:sizeof(H) atIndex:11];
+
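+                // same launch shape as GGML_OP_RWKV_WKV6: one threadgroup per (batch, head), C/H == 64 threads each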
+                [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
             } break;
         case GGML_OP_MUL_MAT:
             {
@@ -2458,171 +2641,180 @@ static void ggml_metal_encode_node(
                     [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                     [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
-                    int nth0 = 32;
-                    int nth1 = 1;
-                    int nrows = 1;
-                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
                     id<MTLComputePipelineState> pipeline = nil;
 
+                    int nsg = 0; // number of simdgroups
+                    int nr0 = 0; // number of src0 rows per simdgroup
+                    int nr1 = 1; // number of src1 rows per threadgroup
+
+                    size_t smem = 0; // shared memory
+
                     // use custom matrix x vector kernel
                     switch (src0t) {
                         case GGML_TYPE_F32:
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                nsg = 1;
+                                nr0 = 1;
+                                nr1 = 4;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
-                                nrows = 4;
                             } break;
                         case GGML_TYPE_F16:
                             {
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 if (src1t == GGML_TYPE_F32) {
                                     if (ne11 * ne12 < 4) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
-                                        nrows = ne11;
+                                        nr1 = ne11;
                                     } else {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
-                                        nrows = 4;
+                                        nr1 = 4;
                                     }
                                 } else {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
-                                    nrows = 4;
+                                    nr1 = 4;
                                 }
                             } break;
                         case GGML_TYPE_BF16:
                             {
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 if (src1t == GGML_TYPE_F32) {
                                     if (ne11 * ne12 < 4) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
-                                        nrows = ne11;
+                                        nr1 = ne11;
                                     } else {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
-                                        nrows = 4;
+                                        nr1 = 4;
                                     }
                                 } else {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
-                                    nrows = 4;
+                                    nr1 = 4;
                                 }
                             } break;
                         case GGML_TYPE_Q4_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_0;
+                                nr0 = N_R0_Q4_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q4_1:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_1;
+                                nr0 = N_R0_Q4_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_0;
+                                nr0 = N_R0_Q5_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_1:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_1;
+                                nr0 = N_R0_Q5_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q8_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q8_0;
+                                nr0 = N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q2_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q2_K;
+                                nr0 = N_R0_Q2_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q3_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q3_K;
+                                nr0 = N_R0_Q3_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q4_K:
                             {
-                                nth0 = 4; //1;
-                                nth1 = 8; //32;
+                                nsg = N_SG_Q4_K;
+                                nr0 = N_R0_Q4_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q5_K;
+                                nr0 = N_R0_Q5_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q6_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q6_K;
+                                nr0 = N_R0_Q6_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_XXS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XXS;
+                                nr0 = N_R0_IQ2_XXS;
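+                                // threadgroup memory for the shared quant grid + sign tables used by this kernel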
+                                smem = 256*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_XS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XS;
+                                nr0 = N_R0_IQ2_XS;
+                                smem = 512*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ3_XXS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_XXS;
+                                nr0 = N_R0_IQ3_XXS;
+                                smem = 256*4+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ3_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_S;
+                                nr0 = N_R0_IQ3_S;
+                                smem = 512*4;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_S;
+                                nr0 = N_R0_IQ2_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ1_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_S;
+                                nr0 = N_R0_IQ1_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ1_M:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_M;
+                                nr0 = N_R0_IQ1_M;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ4_NL:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_NL;
+                                nr0 = N_R0_IQ4_NL;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ4_XS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_XS;
+                                nr0 = N_R0_IQ4_XS;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
                             } break;
                         default:
@@ -2659,41 +2851,10 @@ static void ggml_metal_encode_node(
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
-                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                        src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                        src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                        const int mem_size = 32*sizeof(float);
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q3_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    } else {
-                        const int64_t ny = (ne11 + nrows - 1)/nrows;
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    if (smem > 0) {
+                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
                     }
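+                    // nsg simdgroups of 32 threads per threadgroup: each simdgroup handles nr0 src0 rows, each threadgroup covers nr1 src1 rows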
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                 }
             } break;
         case GGML_OP_MUL_MAT_ID:
@@ -2719,20 +2880,19 @@ static void ggml_metal_encode_node(
                 // ne21 = n_rows
                 const int dst_rows = ne20*ne21;
                 const int dst_rows_min = n_as;
-                const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
+                const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
 
                 // max size of the rowids array in the kernel shared buffer
-                GGML_ASSERT(dst_rows <= dst_rows_max);
+                //GGML_ASSERT(dst_rows <= dst_rows_max);
 
                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                // !!!
-                // TODO: for now, always use mat-vec kernels until we figure out how to improve the
-                //       indirect matrix multiplication
-                // !!!
                 if ([device supportsFamily:MTLGPUFamilyApple7] &&
                         ne00 % 32 == 0 && ne00 >= 64 &&
-                        dst_rows > dst_rows_min) {
+                        //ne01 / ne02 >= 512 &&    // NOTE: this is based on Mixtral shapes, might need adjustments
+                        dst_rows >  dst_rows_min &&
+                        dst_rows <= dst_rows_max) {
+
                     // some Metal matrix data types require aligned pointers
                     // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
                     switch (src0->type) {
@@ -2799,146 +2959,155 @@ static void ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
-                    int nth0 = 32;
-                    int nth1 = 1;
-                    int nrows = 1;
-                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
                     id<MTLComputePipelineState> pipeline = nil;
 
+                    int nsg = 0; // number of simdgroups
+                    int nr0 = 0; // number of src0 rows per simdgroup
+                    int nr1 = 1; // number of src1 rows per threadgroup
+
+                    size_t smem = 0; // shared memory
+
                     // use custom matrix x vector kernel
                     switch (src0t) {
                         case GGML_TYPE_F32:
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                nsg = 1;
+                                nr0 = 1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
                             } break;
                         case GGML_TYPE_F16:
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
                             } break;
                         case GGML_TYPE_BF16:
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q4_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_0;
+                                nr0 = N_R0_Q4_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q4_1:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_1;
+                                nr0 = N_R0_Q4_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_0;
+                                nr0 = N_R0_Q5_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_1:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_1;
+                                nr0 = N_R0_Q5_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q8_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q8_0;
+                                nr0 = N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q2_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q2_K;
+                                nr0 = N_R0_Q2_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q3_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q3_K;
+                                nr0 = N_R0_Q3_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q4_K:
                             {
-                                nth0 = 4; //1;
-                                nth1 = 8; //32;
+                                nsg = N_SG_Q4_K;
+                                nr0 = N_R0_Q4_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q5_K;
+                                nr0 = N_R0_Q5_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q6_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q6_K;
+                                nr0 = N_R0_Q6_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_XXS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XXS;
+                                nr0 = N_R0_IQ2_XXS;
+                                smem = 256*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_XS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XS;
+                                nr0 = N_R0_IQ2_XS;
+                                smem = 512*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ3_XXS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_XXS;
+                                nr0 = N_R0_IQ3_XXS;
+                                smem = 256*4+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ3_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_S;
+                                nr0 = N_R0_IQ3_S;
+                                smem = 512*4;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_S;
+                                nr0 = N_R0_IQ2_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ1_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_S;
+                                nr0 = N_R0_IQ1_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ1_M:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_M;
+                                nr0 = N_R0_IQ1_M;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ4_NL:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_NL;
+                                nr0 = N_R0_IQ4_NL;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ4_XS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_XS;
+                                nr0 = N_R0_IQ4_XS;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
                             } break;
                         default:
@@ -2949,7 +3118,7 @@ static void ggml_metal_encode_node(
                     };
 
                     if (ggml_is_quantized(src0t)) {
-                        GGML_ASSERT(ne00 >= nth0*nth1);
+                        GGML_ASSERT(ne00 >= nsg*nr0);
                     }
 
                     ggml_metal_kargs_mul_mv_id args = {
@@ -2982,43 +3151,12 @@ static void ggml_metal_encode_node(
                     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
 
                     const int64_t _ne1 = 1;
-                    const int tgz = dst_rows;
+                    const int64_t ne123 = dst_rows;
 
-                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                            src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                            src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                        const int mem_size = 32*sizeof(float);
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q3_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    } else {
-                        const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    if (smem > 0) {
+                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
                     }
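+                    // same launch shape as the GGML_OP_MUL_MAT vector kernels, with one z slice per dst row (ne123 == dst_rows)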
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                 }
             } break;
         case GGML_OP_GET_ROWS:
@@ -3052,19 +3190,22 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("not implemented");
                 }
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_get_rows args = {
+                    /*.ne00 =*/ ne00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.ne10 =*/ ne10,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
                 [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
-                [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
-                [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
-                [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
-                [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
+                [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
             } break;
@@ -3102,6 +3243,42 @@ static void ggml_metal_encode_node(
 
                 const int64_t nrows = ggml_nrows(src0);
 
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_L2_NORM:
+            {
+                GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+                float eps;
+                memcpy(&eps, dst->op_params, sizeof(float));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, ne00/4);
+
+                ggml_metal_kargs_l2_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb01   =*/ nb01,
+                    /*.eps    =*/ eps,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                const int64_t nrows = ggml_nrows(src0);
+
                 [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_GROUP_NORM:
@@ -3121,18 +3298,21 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_group_norm args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.n_groups =*/ n_groups,
+                    /*.eps =*/ eps,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
                 [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne01     length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&ne02     length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&nb00     length:sizeof(uint64_t) atIndex:5];
-                [encoder setBytes:&nb01     length:sizeof(uint64_t) atIndex:6];
-                [encoder setBytes:&nb02     length:sizeof(uint64_t) atIndex:7];
-                [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
-                [encoder setBytes:&eps      length:sizeof(   float) atIndex:9];
+                [encoder setBytes:&args     length:sizeof(args)     atIndex:2];
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -3290,8 +3470,8 @@ static void ggml_metal_encode_node(
 
                 const int32_t CHW = IC * KH * KW;
 
-                const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
-                const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+                const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
 
@@ -3313,27 +3493,30 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_im2col args = {
+                    /*.ofs0 =*/ ofs0,
+                    /*.ofs1 =*/ ofs1,
+                    /*.IW   =*/ IW,
+                    /*.IH   =*/ IH,
+                    /*.CHW  =*/ CHW,
+                    /*.s0   =*/ s0,
+                    /*.s1   =*/ s1,
+                    /*.p0   =*/ p0,
+                    /*.p1   =*/ p1,
+                    /*.d0   =*/ d0,
+                    /*.d1   =*/ d1,
+                    /*.N    =*/ N,
+                    /*.KH   =*/ KH,
+                    /*.KW   =*/ KW,
+                    /*.KHW  =*/ KH * KW,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src1 offset:offs_src1       atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
-                [encoder setBytes:&ofs0    length:sizeof(int32_t) atIndex:2];
-                [encoder setBytes:&ofs1    length:sizeof(int32_t) atIndex:3];
-                [encoder setBytes:&IW      length:sizeof(int32_t) atIndex:4];
-                [encoder setBytes:&IH      length:sizeof(int32_t) atIndex:5];
-                [encoder setBytes:&CHW     length:sizeof(int32_t) atIndex:6];
-                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:7];
-                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:8];
-                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:9];
-                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:10];
-                [encoder setBytes:&d0      length:sizeof(int32_t) atIndex:11];
-                [encoder setBytes:&d1      length:sizeof(int32_t) atIndex:12];
+                [encoder setBytes:&args length:sizeof(args)       atIndex:2];
 
                 if (is_gt_mttpt) {
-                    [encoder setBytes:&N   length:sizeof(int32_t) atIndex:13];
-                    [encoder setBytes:&KH  length:sizeof(int32_t) atIndex:14];
-                    [encoder setBytes:&KW  length:sizeof(int32_t) atIndex:15];
-
                     const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
 
                     const int64_t  quotient  = N / n_threads + (N % n_threads > 0 ? 1 : 0);
@@ -3373,16 +3556,20 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
+                ggml_metal_kargs_conv_transpose_1d args = {
+                    /*.IC =*/ IC,
+                    /*.IL =*/ IL,
+                    /*.K  =*/ K,
+                    /*.s0 =*/ s0,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0         atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1         atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst          atIndex:2];
-                [encoder setBytes:&IC      length:sizeof( int32_t)  atIndex:3];
-                [encoder setBytes:&IL      length:sizeof( int32_t)  atIndex:4];
-                [encoder setBytes:&K       length:sizeof( int32_t)  atIndex:5];
-                [encoder setBytes:&s0      length:sizeof( int32_t)  atIndex:6];
-                [encoder setBytes:&nb0     length:sizeof(uint64_t)  atIndex:7];
-                [encoder setBytes:&nb1     length:sizeof(uint64_t)  atIndex:8];
+                [encoder setBytes:&args    length:sizeof(args)       atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -3397,30 +3584,33 @@ static void ggml_metal_encode_node(
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_upscale args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0 =*/ ne0,
+                    /*.ne1 =*/ ne1,
+                    /*.ne2 =*/ ne2,
+                    /*.ne3 =*/ ne3,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                    /*.nb3 =*/ nb3,
+                    /*.sf0 =*/ sf0,
+                    /*.sf1 =*/ sf1,
+                    /*.sf2 =*/ sf2,
+                    /*.sf3 =*/ sf3
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
-                [encoder setBytes:&sf0  length:sizeof(sf0)  atIndex:18];
-                [encoder setBytes:&sf1  length:sizeof(sf1)  atIndex:19];
-                [encoder setBytes:&sf2  length:sizeof(sf2)  atIndex:20];
-                [encoder setBytes:&sf3  length:sizeof(sf3)  atIndex:21];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
@@ -3432,26 +3622,29 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_pad args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0 =*/ ne0,
+                    /*.ne1 =*/ ne1,
+                    /*.ne2 =*/ ne2,
+                    /*.ne3 =*/ ne3,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                    /*.nb3 =*/ nb3
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN(1024, ne0);
 
@@ -3466,24 +3659,31 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
 
+                ggml_metal_kargs_pad_reflect_1d args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0 =*/ ne0,
+                    /*.ne1 =*/ ne1,
+                    /*.ne2 =*/ ne2,
+                    /*.ne3 =*/ ne3,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                    /*.nb3 =*/ nb3,
+                    /*.p0 =*/ p0,
+                    /*.p1 =*/ p1
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:11];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:12];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:13];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:14];
-                [encoder setBytes:&p0   length:sizeof(p0)   atIndex:15];
-                [encoder setBytes:&p1   length:sizeof(p1)   atIndex:16];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN(1024, ne0);
 
@@ -3531,12 +3731,15 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_arange args = {
+                    /*.ne0 =*/ ne0,
+                    /*.start =*/ start,
+                    /*.step =*/ step
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_dst  offset:offs_dst    atIndex:0];
-                [encoder setBytes:&ne0   length:sizeof(ne0)   atIndex:1];
-                [encoder setBytes:&start length:sizeof(start) atIndex:2];
-                [encoder setBytes:&step  length:sizeof(step)  atIndex:3];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:0];
+                [encoder setBytes:&args length:sizeof(args) atIndex:1];
 
                 const int nth = MIN(1024, ne0);
 
@@ -3553,13 +3756,16 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_timestep_embedding args = {
+                    /*.nb1 =*/ nb1,
+                    /*.dim =*/ dim,
+                    /*.max_period =*/ max_period
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&nb1   length:sizeof(nb1) atIndex:2];
-                [encoder setBytes:&dim   length:sizeof(dim) atIndex:3];
-                [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN(1024, half);
 
@@ -3592,12 +3798,15 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_argsort args = {
+                    /*.ncols =*/ ne00,
+                    /*.ncols_pad =*/ ne00_padded
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
@@ -3611,11 +3820,14 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_leaky_relu args = {
+                    /*.slope =*/ slope
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
-                [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args)   atIndex:2];
 
                 const int64_t n = ggml_nelements(dst);
 
@@ -3629,7 +3841,9 @@ static void ggml_metal_encode_node(
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
                 GGML_ASSERT(src1->type == src2->type);
 
-                GGML_ASSERT(ggml_are_same_shape (src1, src2));
+                //GGML_ASSERT(ggml_are_same_shape (src1, src2));
+                GGML_ASSERT(ne11 == ne21);
+                GGML_ASSERT(ne12 == ne22);
 
                 struct ggml_tensor * src3 = node->src[3];
 
@@ -3676,125 +3890,161 @@ static void ggml_metal_encode_node(
 
                 // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
                 //       for now avoiding mainly to keep the number of templates/kernels a bit lower
-                if (ne01 >= 4 || (ne00%128 != 0)) {
+                //       these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
+                if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
                     switch (src1->type) {
                         case GGML_TYPE_F16:
                             {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
+                                if (ne00 == 192 && ne20 == 128) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
+                                } else {
+                                    switch (ne00) {
+                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break;
+                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                                        default:
+                                                  {
+                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                      GGML_LOG_ERROR("add template specialization for this size\n");
+                                                      GGML_ABORT("add template specialization for this size");
+                                                  }
+                                    }
                                 }
                             } break;
                         case GGML_TYPE_BF16:
                             {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
+                                if (ne00 == 192 && ne20 == 128) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
+                                } else {
+                                    switch (ne00) {
+                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break;
+                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                                        default:
+                                                  {
+                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                      GGML_LOG_ERROR("add template specialization for this size\n");
+                                                      GGML_ABORT("add template specialization for this size");
+                                                  }
+                                    }
                                 }
                             } break;
                         case GGML_TYPE_Q4_0:
                             {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
+                                if (ne00 == 192 && ne20 == 128) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
+                                } else {
+                                    switch (ne00) {
+                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
+                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
+                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
+                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
+                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
+                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break;
+                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
+                                        default:
+                                                  {
+                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                      GGML_LOG_ERROR("add template specialization for this size\n");
+                                                      GGML_ABORT("add template specialization for this size");
+                                                  }
+                                    }
                                 }
                             } break;
                         case GGML_TYPE_Q4_1:
                             {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
+                                if (ne00 == 192 && ne20 == 128) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
+                                } else {
+                                    switch (ne00) {
+                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
+                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
+                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
+                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
+                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
+                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break;
+                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
+                                        default:
+                                                  {
+                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                      GGML_LOG_ERROR("add template specialization for this size\n");
+                                                      GGML_ABORT("add template specialization for this size");
+                                                  }
+                                    }
                                 }
                             } break;
                         case GGML_TYPE_Q5_0:
                             {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
+                                if (ne00 == 192 && ne20 == 128) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
+                                } else {
+                                    switch (ne00) {
+                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
+                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
+                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
+                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
+                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
+                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break;
+                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
+                                        default:
+                                                  {
+                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                      GGML_LOG_ERROR("add template specialization for this size\n");
+                                                      GGML_ABORT("add template specialization for this size");
+                                                  }
+                                    }
                                 }
                             } break;
                         case GGML_TYPE_Q5_1:
                             {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
+                                if (ne00 == 192 && ne20 == 128) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
+                                } else {
+                                    switch (ne00) {
+                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
+                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
+                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
+                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
+                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
+                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break;
+                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
+                                        default:
+                                                  {
+                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                      GGML_LOG_ERROR("add template specialization for this size\n");
+                                                      GGML_ABORT("add template specialization for this size");
+                                                  }
+                                    }
                                 }
                             } break;
                         case GGML_TYPE_Q8_0:
                             {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
+                                if (ne00 == 192 && ne20 == 128) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
+                                } else {
+                                    switch (ne00) {
+                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
+                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
+                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
+                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
+                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
+                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break;
+                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
+                                        default:
+                                                  {
+                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                      GGML_LOG_ERROR("add template specialization for this size\n");
+                                                      GGML_ABORT("add template specialization for this size");
+                                                  }
+                                    }
                                 }
                             } break;
                         default:
@@ -3826,6 +4076,42 @@ static void ggml_metal_encode_node(
                                         }
                                 }
                             } break;
+                        case 192:
+                            {
+                                if (ne20 == 128) {
+                                    switch (src1->type) {
+                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break;
+                                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break;
+                                        case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break;
+                                        case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break;
+                                        case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break;
+                                        case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break;
+                                        case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break;
+                                        default:
+                                            {
+                                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                                GGML_LOG_ERROR("add template specialization for this type\n");
+                                                GGML_ABORT("add template specialization for this type");
+                                            }
+                                    }
+                                } else {
+                                    switch (src1->type) {
+                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break;
+                                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break;
+                                        case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break;
+                                        case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break;
+                                        case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break;
+                                        case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break;
+                                        case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break;
+                                        default:
+                                            {
+                                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                                GGML_LOG_ERROR("add template specialization for this type\n");
+                                                GGML_ABORT("add template specialization for this type");
+                                            }
+                                    }
+                                }
+                            } break;
                         case 256:
                             {
                                 switch (src1->type) {
@@ -3863,9 +4149,12 @@ static void ggml_metal_encode_node(
                     /*.ne11          =*/ ne11,
                     /*.ne_12_2       =*/ ne12,
                     /*.ne_12_3       =*/ ne13,
-                    /*.nb_12_1       =*/ nb11,
-                    /*.nb_12_2       =*/ nb12,
-                    /*.nb_12_3       =*/ nb13,
+                    /*.nb11          =*/ nb11,
+                    /*.nb12          =*/ nb12,
+                    /*.nb13          =*/ nb13,
+                    /*.nb21          =*/ nb21,
+                    /*.nb22          =*/ nb22,
+                    /*.nb23          =*/ nb23,
                     /*.nb31          =*/ nb31,
                     /*.ne1           =*/ ne1,
                     /*.ne2           =*/ ne2,
@@ -3944,10 +4233,9 @@ static void ggml_metal_encode_node(
                     // ne00*(nsg)
                     // each simdgroup has a full f16 head vector in shared mem to accumulate results
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
-
                     while (true) {
                         const size_t smem = FATTN_SMEM(nsgmax);
                         if (smem > device.maxThreadgroupMemoryLength) {
@@ -4191,21 +4479,24 @@ static void ggml_metal_encode_node(
                 const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
                 const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_pool_2d args_pool_2d = {
+                    /* .k0 = */ k0,
+                    /* .k1 = */ k1,
+                    /* .s0 = */ s0,
+                    /* .s1 = */ s1,
+                    /* .p0 = */ p0,
+                    /* .p1 = */ p1,
+                    /* .IH = */ IH,
+                    /* .IW = */ IW,
+                    /* .OH = */ OH,
+                    /* .OW = */ OW,
+                    /* .parallel_elements = */ parallel_elements
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0       atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
-                [encoder setBytes:&k0      length:sizeof(int32_t) atIndex:2];
-                [encoder setBytes:&k1      length:sizeof(int32_t) atIndex:3];
-                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:4];
-                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:5];
-                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:6];
-                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:7];
-                [encoder setBytes:&IH      length:sizeof(int64_t) atIndex:8];
-                [encoder setBytes:&IW      length:sizeof(int64_t) atIndex:9];
-                [encoder setBytes:&OH      length:sizeof(int64_t) atIndex:10];
-                [encoder setBytes:&OW      length:sizeof(int64_t) atIndex:11];
-                [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
             } break;
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
index bb0ff6688..ede9d1e6c 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
@@ -3,8 +3,7 @@
 #if defined(GGML_METAL_EMBED_LIBRARY)
 __embed_ggml-common.h__
 #else
-// TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift
-#include "../ggml-common.h"
+#include "ggml-common.h"
 #endif
 #include "ggml-metal-impl.h"
 
@@ -49,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4>
 void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
-    reg = (type4)(*(src + il));
+    reg = (type4)(*(src));
 }
 
 #if defined(GGML_METAL_USE_BF16)
@@ -57,6 +56,11 @@ template <typename type4x4>
 void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
     reg = (type4x4)(*src);
 }
+
+template <typename type4>
+void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
+    reg = (type4)(*(src));
+}
 #endif
 
 template <typename type4x4>
@@ -955,45 +959,22 @@ kernel void kernel_neg(
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
+        constant ggml_metal_kargs_sum_rows & args,
         uint3 tpig[[thread_position_in_grid]]) {
     int64_t i3 = tpig.z;
     int64_t i2 = tpig.y;
     int64_t i1 = tpig.x;
 
-    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }
 
-    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
     float row_sum = 0;
 
-    for (int64_t i0 = 0; i0 < ne00; i0++) {
+    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
         row_sum += src_row[i0];
     }
 
@@ -1005,36 +986,29 @@ kernel void kernel_soft_max(
         device const  char * src0,
         device const  char * src1,
         device        char * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint32_t & n_head_log2,
+        constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (ne02*ne01);
-    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
-    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+    const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
+    const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
+    const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
 
-    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
-    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
+    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*args.ne00 : nullptr;
+    device       float * pdst  = (device       float *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
 
     float slope = 1.0f;
 
     // ALiBi
-    if (max_bias > 0.0f) {
+    if (args.max_bias > 0.0f) {
         const int64_t h = i02;
 
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
         slope = pow(base, exp);
     }
@@ -1042,8 +1016,8 @@ kernel void kernel_soft_max(
     // parallel max
     float lmax = -INFINITY;
 
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
 
     // find the max value in the block
@@ -1067,8 +1041,8 @@ kernel void kernel_soft_max(
 
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+        const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
         pdst[i00] = exp_psrc0;
     }
@@ -1098,7 +1072,7 @@ kernel void kernel_soft_max(
 
     const float inv_sum = 1.0f/sum;
 
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
         pdst[i00] *= inv_sum;
     }
 }
@@ -1108,35 +1082,28 @@ kernel void kernel_soft_max_4(
         device const  char * src0,
         device const  char * src1,
         device        char * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint32_t & n_head_log2,
+        constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (ne02*ne01);
-    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
-    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+    const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
+    const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
+    const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
 
-    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
-    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
+    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*args.ne00/4 : nullptr;
+    device       float4 * pdst4 = (device       float4 *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
 
     float slope = 1.0f;
 
-    if (max_bias > 0.0f) {
+    if (args.max_bias > 0.0f) {
         const int64_t h = i02;
 
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
         slope = pow(base, exp);
     }
@@ -1144,8 +1111,8 @@ kernel void kernel_soft_max_4(
     // parallel max
     float4 lmax4 = -INFINITY;
 
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -1170,8 +1137,8 @@ kernel void kernel_soft_max_4(
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -1203,7 +1170,7 @@ kernel void kernel_soft_max_4(
 
     const float inv_sum = 1.0f/sum;
 
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
         pdst4[i00] *= inv_sum;
     }
 }
@@ -1219,27 +1186,23 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant       int & n_past,
+        constant ggml_metal_kargs_diag_mask_inf & args,
         uint3 tpig[[thread_position_in_grid]]) {
     const int64_t i02 = tpig[2];
     const int64_t i01 = tpig[1];
     const int64_t i00 = tpig[0];
 
-    if (i00 > n_past + i01) {
-        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    if (i00 > args.n_past + i01) {
+        dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY;
     } else {
-        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+        dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00];
     }
 }
 
 kernel void kernel_diag_mask_inf_8(
         device const float4 * src0,
         device       float4 * dst,
-        constant    int64_t & ne00,
-        constant    int64_t & ne01,
-        constant        int & n_past,
+        constant ggml_metal_kargs_diag_mask_inf & args,
         uint3 tpig[[thread_position_in_grid]]) {
 
     const int64_t i = 2*tpig[0];
@@ -1247,42 +1210,26 @@ kernel void kernel_diag_mask_inf_8(
     dst[i+0] = src0[i+0];
     dst[i+1] = src0[i+1];
     int64_t i4 = 4*i;
-    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
-    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01;
+    const int64_t i01 = i4/(args.ne00);      i4 -= i01*args.ne00;
     const int64_t i00 = i4;
     for (int k = 3; k >= 0; --k) {
-        if (i00 + 4 + k <= n_past + i01) {
+        if (i00 + 4 + k <= args.n_past + i01) {
             break;
         }
         dst[i+1][k] = -INFINITY;
-        if (i00 + k > n_past + i01) {
+        if (i00 + k > args.n_past + i01) {
             dst[i][k] = -INFINITY;
         }
     }
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
-// TODO: optimize
 kernel void kernel_ssm_conv_f32(
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_ssm_conv & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -1290,15 +1237,15 @@ kernel void kernel_ssm_conv_f32(
     const int64_t i2 = tgpig.y;
     const int64_t i3 = tgpig.z;
 
-    const int64_t nc  = ne10;
-  //const int64_t ncs = ne00;
-  //const int64_t nr  = ne01;
-  //const int64_t n_t = ne1;
-  //const int64_t n_s = ne2;
+    const int64_t nc  = args.ne10;
+  //const int64_t ncs = args.ne00;
+  //const int64_t nr  = args.ne01;
+  //const int64_t n_t = args.ne1;
+  //const int64_t n_s = args.ne2;
 
-    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
-    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
-    device       float * x = (device       float *) ((device       char *) dst  + ir*nb0  + i2*nb1  + i3*nb2);
+    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
+    device       float * x = (device       float *) ((device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2);
 
     float sumf = 0.0f;
 
@@ -1310,7 +1257,6 @@ kernel void kernel_ssm_conv_f32(
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32
-// TODO: optimize
 kernel void kernel_ssm_scan_f32(
         device const void * src0,
         device const void * src1,
@@ -1319,48 +1265,27 @@ kernel void kernel_ssm_scan_f32(
         device const void * src4,
         device const void * src5,
         device      float * dst,
-        constant  int64_t & d_state,
-        constant  int64_t & d_inner,
-        constant  int64_t & n_seq_tokens,
-        constant  int64_t & n_seqs,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant uint64_t & nb20,
-        constant uint64_t & nb21,
-        constant uint64_t & nb22,
-        constant uint64_t & nb30,
-        constant uint64_t & nb31,
-        constant uint64_t & nb40,
-        constant uint64_t & nb41,
-        constant uint64_t & nb42,
-        constant uint64_t & nb50,
-        constant uint64_t & nb51,
-        constant uint64_t & nb52,
+        constant ggml_metal_kargs_ssm_scan & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
     const int64_t ir = tgpig.x;
     const int64_t i3 = tgpig.y;
 
-    const int64_t nc  = d_state;
-  //const int64_t nr  = d_inner;
-    const int64_t n_t = n_seq_tokens;
-  //const int64_t n_s = n_seqs;
+    const int64_t nc  = args.d_state;
+    // const int64_t nr  = args.d_inner;
+    const int64_t n_t = args.n_seq_tokens;
+    // const int64_t n_s = args.n_seqs;
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
-        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
-        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
-        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
-        device       float * y  = (device       float *) ((device       char *) dst  + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
-        device       float * s  = (device       float *) ((device       char *) dst  + ir*nb01 + i3*nb02 +    nb13);
+        device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
+        device const float * x  = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31);
+        device const float * B  = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
+        device const float * C  = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
+        device       float * y  = (device       float *) ((device       char *) dst  + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
+        device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb01 + i3*args.nb02 +    args.nb13);
 
         if (i2 > 0) {
             s0 = s;
@@ -1382,6 +1307,184 @@ kernel void kernel_ssm_scan_f32(
     }
 }
 
+kernel void kernel_rwkv_wkv6_f32(
+    device const float * k,
+    device const float * v,
+    device const float * r,
+    device const float * tf,
+    device const float * td,
+    device const float * state_in,
+    device       float * dst,
+    constant    uint & B,
+    constant    uint & T,
+    constant    uint & C,
+    constant    uint & H,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]])  {
+
+    const uint head_size = 64; // TODO: support head_size = 128
+    const uint batch_id = tgpig.x / H;
+    const uint head_id = tgpig.x % H;
+    const uint tid = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _k[head_size];
+    threadgroup float _r[head_size];
+    threadgroup float _tf[head_size];
+    threadgroup float _td[head_size];
+
+    float state[head_size];
+
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                          + i * head_size + tid];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    _tf[tid] = tf[head_id * head_size + tid];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0;
+
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            float4 temp = tf_vec * kv + s_vec;
+            y += dot(r_vec, temp);
+
+            s_vec = s_vec * td_vec + kv;
+            state[j]   = s_vec[0];
+            state[j+1] = s_vec[1];
+            state[j+2] = s_vec[2];
+            state[j+3] = s_vec[3];
+        }
+
+        dst[t] = y;
+    }
+
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + i * head_size + tid] = state[i];
+    }
+}
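+
+// In scalar form, the token loop above evaluates, per batch and head, with tid
+// selecting the output channel and j running over the head dimension, roughly:
+//
+//     kv[j]     = k[t][j] * v[t][tid]
+//     y[t][tid] = sum_j r[t][j] * (tf[j]*kv[j] + state[j])
+//     state[j]  = state[j]*td[t][j] + kv[j]
+//
+// tf ("time first") is loaded once per head, while k, r and td are reloaded for
+// every token; the final state is appended after the T*C outputs in dst.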
+
+kernel void kernel_rwkv_wkv7_f32(
+    device const float * r,
+    device const float * w,
+    device const float * k,
+    device const float * v,
+    device const float * a,
+    device const float * b,
+    device const float * state_in,
+    device       float * dst,
+    constant    uint & B,
+    constant    uint & T,
+    constant    uint & C,
+    constant    uint & H,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]])  {
+
+    const uint head_size = 64; // TODO: support head_size = 128
+    const uint batch_id = tgpig.x / H;
+    const uint head_id = tgpig.x % H;
+    const uint tid = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _r[head_size];
+    threadgroup float _w[head_size];
+    threadgroup float _k[head_size];
+    threadgroup float _a[head_size];
+    threadgroup float _b[head_size];
+
+    float state[head_size];
+
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                          + tid * head_size + i];
+    }
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0, sa = 0.0;
+
+        float4 sa_vec(0.0);
+
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+            sa_vec += a_vec * s_vec;
+        }
+        sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
+
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            s_vec = s_vec * w_vec + kv + sa * b_vec;
+            y += dot(s_vec, r_vec);
+
+            state[j]   = s_vec[0];
+            state[j+1] = s_vec[1];
+            state[j+2] = s_vec[2];
+            state[j+3] = s_vec[3];
+        }
+
+        dst[t] = y;
+    }
+
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + tid * head_size + i] = state[i];
+    }
+}
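+
+// Same structure as wkv6 above, but for the wkv7 recurrence; per output channel tid,
+// roughly:
+//
+//     sa       = sum_j a[t][j] * state[j]
+//     state[j] = state[j]*w[t][j] + k[t][j]*v[t][tid] + sa*b[t][j]
+//     y[t]     = sum_j state[j] * r[t][j]        (using the updated state)
+//
+// Note that the state is laid out transposed relative to wkv6 (tid*head_size + j
+// instead of j*head_size + tid), both when reading state_in and when writing it back.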
+
 kernel void kernel_argmax(
         device   const void * x,
         device      int32_t * dst,
@@ -1550,25 +1653,61 @@ kernel void kernel_rms_norm(
     }
 }
 
+kernel void kernel_l2_norm(
+        constant ggml_metal_kargs_l2_norm & args,
+        device const char * src0,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+    float sumf = 0.0f;
+
+    // parallel sum
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        sumf += dot(x[i00], x[i00]);
+    }
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float scale = 1.0f/sqrt(max(sumf, args.eps));
+
+    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
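+
+// i.e. each row of ne00 elements is scaled by 1/sqrt(max(sum(x^2), eps)): a plain L2
+// normalization, with the row sum of squares reduced per simdgroup and then across
+// simdgroups through shmem_f32.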
+
 kernel void kernel_group_norm(
         device const float * src0,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int32_t & n_groups,
-        constant     float & eps,
+        constant ggml_metal_kargs_group_norm & args,
         threadgroup float  * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    const int64_t ne = ne00*ne01*ne02;
-    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+    const int64_t ne = args.ne00*args.ne01*args.ne02;
+    const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups);
 
     int start = tgpig * gs;
     int end   = start + gs;
@@ -1632,7 +1771,7 @@ kernel void kernel_group_norm(
     }
 
     const float variance = tmp / gs;
-    const float scale = 1.0f/sqrt(variance + eps);
+    const float scale = 1.0f/sqrt(variance + args.eps);
     for (int j = start; j < end; j += ntg) {
         dst[j] *= scale;
     }
@@ -1726,14 +1865,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
-// putting them in the kernel cause a significant performance penalty
-#define N_DST 4        // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
-//Note: This is a template, but strictly speaking it only applies to
-//      quantizations where the block size is 32. It also does not
-//      guard against the number of rows not being divisible by
-//      N_DST, so this is another explicit assumption of the implementation.
-template
+template
 void mul_vec_q_n_f32_impl(
         args_t args,
         device const char * src0,
@@ -1749,7 +1881,7 @@ void mul_vec_q_n_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -1761,15 +1893,15 @@ void mul_vec_q_n_f32_impl(
     device const float        * y = (device const float        *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q_type * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q_type * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
     }
 
     float yl[16]; // src1 vector cache
-    float sumf[nr] = {0.f};
+    float sumf[nr0] = {0.f};
 
     const short ix = (tiisg/2);
     const short il = (tiisg%2)*8;
@@ -1781,7 +1913,7 @@ void mul_vec_q_n_f32_impl(
         float sumy[2] = { 0.f, 0.f };
 
 #pragma unroll
-        for (int i = 0; i < 8; i += 2) {
+        for (short i = 0; i < 8; i += 2) {
             sumy[0]  += yb[i +  0] + yb[i +  1];
             yl[i + 0] = yb[i +  0];
             yl[i + 1] = yb[i +  1]/256.f;
@@ -1792,7 +1924,7 @@ void mul_vec_q_n_f32_impl(
         }
 
 #pragma unroll
-        for (int row = 0; row < nr; row++) {
+        for (short row = 0; row < nr0; row++) {
             sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
         }
 
@@ -1801,7 +1933,7 @@ void mul_vec_q_n_f32_impl(
 
     device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
 
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
         if (tiisg == 0 && first_row + row < args.ne01) {
@@ -1818,7 +1950,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -1829,7 +1961,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+     mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -1840,7 +1972,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -1851,12 +1983,12 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 #define NB_Q8_0 8
 
-template
+template
 void kernel_mul_mv_q8_0_f32_impl(
         args_t args,
         device const char * src0,
@@ -1866,16 +1998,13 @@ void kernel_mul_mv_q8_0_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
-    const int nr  = N_DST;
-    const int nsg = N_SIMDGROUP;
-    const int nw  = N_SIMDWIDTH;
-
     const int nb = args.ne00/QK8_0;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0*nsg + sgitg)*nr;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -1887,15 +2016,15 @@ void kernel_mul_mv_q8_0_f32_impl(
     device const float      * y = (device const float      *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q8_0 * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q8_0 * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
     }
 
     float yl[NB_Q8_0];
-    float sumf[nr] = { 0.f };
+    float sumf[nr0] = { 0.f };
 
     const short ix = tiisg/4;
     const short il = tiisg%4;
@@ -1908,7 +2037,7 @@ void kernel_mul_mv_q8_0_f32_impl(
             yl[i] = yb[i];
         }
 
-        for (int row = 0; row < nr; row++) {
+        for (short row = 0; row < nr0; row++) {
             device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
             float sumq = 0.f;
             for (short iq = 0; iq < NB_Q8_0; ++iq) {
@@ -1922,7 +2051,7 @@ void kernel_mul_mv_q8_0_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
         if (tiisg == 0 && first_row + row < args.ne01) {
@@ -1940,7 +2069,7 @@ kernel void kernel_mul_mv_q8_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // mat-vec kernel processing in chunks of float4
@@ -2277,9 +2406,9 @@ void kernel_mul_mv_impl(
                 sumf += (T0) x[i] * (T1) y[i];
             }
 
-            float all_sum = simd_sum(sumf);
+            float sum_all = simd_sum(sumf);
             if (tiisg == 0) {
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
             }
         }
     } else {
@@ -2300,10 +2429,10 @@ void kernel_mul_mv_impl(
                 sumf += dot((float4) x4[i], (float4) y4[i]);
             }
 
-            float all_sum = simd_sum(sumf);
+            float sum_all = simd_sum(sumf);
             if (tiisg == 0) {
-                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
             }
         }
     }
@@ -2365,9 +2494,9 @@ kernel void kernel_mul_mv_1row(
         for (int i = tiisg; i < args.ne00; i += 32) {
             sumf += (float) x[i] * (float) y[i];
         }
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
         if (tiisg == 0) {
-            dst_f32[r0] = all_sum;
+            dst_f32[r0] = sum_all;
         }
     } else {
         device const T4     * x4 = (device const T4     *) x;
@@ -2377,11 +2506,11 @@ kernel void kernel_mul_mv_1row(
             sumf += dot((float4) x4[i], y4[i]);
         }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
 
         if (tiisg == 0) {
-            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
-            dst_f32[r0] = all_sum;
+            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+            dst_f32[r0] = sum_all;
         }
     }
 }
@@ -2426,9 +2555,9 @@ kernel void kernel_mul_mv_l4(
             sumf += dot((float4) x4[i], y4[i]);
         }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
         if (tiisg == 0) {
-            dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+            dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
         }
     }
 }
@@ -2596,17 +2725,7 @@ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_
 typedef void (im2col_t)(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2616,17 +2735,7 @@ template 
 kernel void kernel_im2col(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2647,17 +2756,17 @@ kernel void kernel_im2col(
     const int64_t ioh = tgpig[1];
     const int64_t iow = tgpig[2];
 
-    const int64_t iiw = iow*s0 + ikw*d0 - p0;
-    const int64_t iih = ioh*s1 + ikh*d1 - p1;
+    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
+    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
 
-    const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
+    const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
 
     device T * pdst = (device T *) (dst);
 
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
         pdst[offset_dst] = 0.0f;
     } else {
-        const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
+        const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
         pdst[offset_dst] = x[offset_src];
     }
 }
@@ -2668,20 +2777,7 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col;
 typedef void (im2col_ext_t)(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
-        constant   int32_t & N,
-        constant   int32_t & KH,
-        constant   int32_t & KW,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2691,53 +2787,40 @@ template 
 kernel void kernel_im2col_ext(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
-        constant   int32_t & N,
-        constant   int32_t & KH,
-        constant   int32_t & KW,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
-    const int64_t KHW = KH * KW;             // KHW == ntg[1] * ntg[2], KW == ntg[2]
+    const int64_t KHW = (int64_t)args.KHW;
 
-    const int64_t d = tgpig[0] / CHW;
-    const int64_t chw = tgpig[0] % CHW;
+    const int64_t d = tgpig[0] / args.CHW;
+    const int64_t chw = tgpig[0] % args.CHW;
     const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
     const int64_t HW = tgpig[0] % KHW;
 
     const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
-    if (tpitg_0 >= N) {
+    if (tpitg_0 >= args.N) {
         return;
     }
 
-    const int64_t tpitg_1 = HW / KW;
-    const int64_t tpitg_2 = HW % KW;
+    const int64_t tpitg_1 = HW / args.KW;
+    const int64_t tpitg_2 = HW % args.KW;
 
-    const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
-    const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
+    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
 
     const int64_t offset_dst =
-        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
-        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
+        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
+        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
 
     device T * pdst = (device T *) (dst);
 
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
         pdst[offset_dst] = 0.0f;
     } else {
-        const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
-        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
+        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
     }
 }
 
@@ -2748,12 +2831,7 @@ typedef void (conv_transpose_1d_t)(
         device const float * src0,
         device const float * src1,
         device        char * dst,
-        constant   int32_t & IC,
-        constant   int32_t & IL,
-        constant   int32_t & K,
-        constant   int32_t & s0,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
+        constant ggml_metal_kargs_conv_transpose_1d & args,
         uint3   tgpig[[threadgroup_position_in_grid]],
         uint3    tgpg[[threadgroups_per_grid]]);
 
@@ -2762,29 +2840,24 @@ kernel void kernel_conv_transpose_1d(
         device const     T * src0,
         device const float * src1,
         device        char * dst,
-        constant   int32_t & IC,
-        constant   int32_t & IL,
-        constant   int32_t & K,
-        constant   int32_t & s0,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
+        constant ggml_metal_kargs_conv_transpose_1d & args,
         uint3   tgpig[[threadgroup_position_in_grid]],
         uint3   tgpg[[threadgroups_per_grid]]) {
 
     float v = 0.0f;
 
-    for (int64_t c = 0; c < IC; c++) {
-        const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
-        const int32_t input_offset = c * IL;
+    for (int64_t c = 0; c < args.IC; c++) {
+        const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
+        const int32_t input_offset = c * args.IL;
 
-        for (int64_t i = 0; i < IL; i++) {
-            if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
-                v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+        for (int64_t i = 0; i < args.IL; i++) {
+            if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
+                v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
             }
         }
     }
 
-    device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+    device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
 
     dst_ptr[0] = v;
 }
@@ -2794,12 +2867,7 @@ kernel void kernel_conv_transpose_1d(
     device const float * src0,
     device const float * src1,
     device        char * dst,
-    constant   int32_t & IC,
-    constant   int32_t & IL,
-    constant   int32_t & K,
-    constant   int32_t & s0,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
+    constant ggml_metal_kargs_conv_transpose_1d & args,
     uint3   tgpig[[threadgroup_position_in_grid]],
     uint3    tgpg[[threadgroups_per_grid]]);
 
@@ -2808,38 +2876,14 @@ kernel void kernel_conv_transpose_1d(
     device const half  * src0,
     device const float * src1,
     device        char * dst,
-    constant   int32_t & IC,
-    constant   int32_t & IL,
-    constant   int32_t & K,
-    constant   int32_t & s0,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
+    constant ggml_metal_kargs_conv_transpose_1d & args,
     uint3   tgpig[[threadgroup_position_in_grid]],
     uint3    tgpg[[threadgroups_per_grid]]);
 
 kernel void kernel_upscale_f32(
     device  const char * src0,
     device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    constant     float & sf0,
-    constant     float & sf1,
-    constant     float & sf2,
-    constant     float & sf3,
+    constant ggml_metal_kargs_upscale & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
@@ -2848,15 +2892,15 @@ kernel void kernel_upscale_f32(
     const int64_t i2 = tgpig.y;
     const int64_t i1 = tgpig.x;
 
-    const int64_t i03 = i3/sf3;
-    const int64_t i02 = i2/sf2;
-    const int64_t i01 = i1/sf1;
+    const int64_t i03 = i3/args.sf3;
+    const int64_t i02 = i2/args.sf2;
+    const int64_t i01 = i1/args.sf1;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int64_t i00 = i0/sf0;
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int64_t i00 = i0/args.sf0;
 
-        device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0);
+        device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+        device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1  +  i0*args.nb0);
 
         dst_ptr[0] = src0_ptr[0];
     }
@@ -2865,22 +2909,7 @@ kernel void kernel_upscale_f32(
 kernel void kernel_pad_f32(
     device  const char * src0,
     device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
+    constant ggml_metal_kargs_pad & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
@@ -2893,12 +2922,12 @@ kernel void kernel_pad_f32(
     const int64_t i02 = i2;
     const int64_t i01 = i1;
 
-    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
-    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1);
 
-    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
-        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-            if (i0 < ne00) {
+    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            if (i0 < args.ne00) {
                 dst_ptr[i0] = src0_ptr[i0];
             } else {
                 dst_ptr[i0] = 0.0f;
@@ -2908,7 +2937,7 @@ kernel void kernel_pad_f32(
         return;
     }
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         dst_ptr[i0] = 0.0f;
     }
 }
@@ -2916,21 +2945,7 @@ kernel void kernel_pad_f32(
 kernel void kernel_pad_reflect_1d_f32(
     device  const char * src0,
     device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant   int64_t & ne0,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    constant   int32_t & p0,
-    constant   int32_t & p1,
+    constant   ggml_metal_kargs_pad_reflect_1d & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3  tgpg[[threadgroups_per_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2944,17 +2959,17 @@ kernel void kernel_pad_reflect_1d_f32(
     const int64_t i02 = i2;
     const int64_t i01 = i1;
 
-    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
-    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1);
 
-    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
-        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-            if (i0 < p0) {
-                dst_ptr[i0] = src0_ptr[p0 - i0];
-            } else if (i0 < ne0 - p1) {
-                dst_ptr[i0] = src0_ptr[i0 - p0];
+    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            if (i0 < args.p0) {
+                dst_ptr[i0] = src0_ptr[args.p0 - i0];
+            } else if (i0 < args.ne0 - args.p1) {
+                dst_ptr[i0] = src0_ptr[i0 - args.p0];
             } else {
-                dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
+                dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
             }
         }
     }
@@ -3007,44 +3022,40 @@ kernel void kernel_unpad_f32(
 
 kernel void kernel_arange_f32(
     device        char * dst,
-    constant   int64_t & ne0,
-    constant   float   & start,
-    constant   float   & step,
+    constant   ggml_metal_kargs_arange & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
 
     device float * dst_ptr = (device float *) dst;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        dst_ptr[i0] = start + step * i0;
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        dst_ptr[i0] = args.start + args.step * i0;
     }
 }
 
 kernel void kernel_timestep_embedding_f32(
     device  const char * src0,
     device        char * dst,
-    constant  uint64_t & nb1,
-    constant  int      & dim,
-    constant  int      & max_period,
+    constant  ggml_metal_kargs_timestep_embedding & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
 
     int i = tgpig.x;
-    device float * embed_data = (device float *)(dst +  i*nb1);
+    device float * embed_data = (device float *)(dst + i*args.nb1);
 
-    int half_ = dim / 2;
+    int half_ = args.dim / 2;
     for (int j = tpitg.x; j < half_; j += ntg.x) {
         float timestep = ((device float *)src0)[i];
-        float freq = (float)exp(-log((float)max_period) * j / half_);
+        float freq = (float)exp(-log((float)args.max_period) * j / half_);
         float arg = timestep * freq;
         embed_data[j        ] = cos(arg);
         embed_data[j + half_] = sin(arg);
     }
 
-    if (dim % 2 != 0 && tpitg.x == 0) {
-        embed_data[dim] = 0.f;
+    if (args.dim % 2 != 0 && tpitg.x == 0) {
+        embed_data[args.dim] = 0.f;
     }
 }
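+
+// This is the standard sinusoidal timestep embedding: with half = dim/2 and
+// freq_j = exp(-ln(max_period) * j / half), each position i with timestep t = src0[i]
+// gets the row [cos(t*freq_0) ... cos(t*freq_{half-1}), sin(t*freq_0) ... sin(t*freq_{half-1})],
+// plus a trailing zero element when dim is odd.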
 
@@ -3052,8 +3063,7 @@ kernel void kernel_timestep_embedding_f32(
 typedef void (argsort_t)(
         device const float  * x,
         device     int32_t  * dst,
-        constant   int64_t  & ncols,
-        constant   int64_t  & ncols_pad,
+        constant   ggml_metal_kargs_argsort & args,
         threadgroup int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]);
@@ -3062,8 +3072,7 @@ template
 kernel void kernel_argsort_f32_i32(
         device const float   * x,
         device       int32_t * dst,
-        constant     int64_t & ncols,
-        constant     int64_t & ncols_pad,
+        constant   ggml_metal_kargs_argsort & args,
         threadgroup int32_t  * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]) {
@@ -3071,9 +3080,9 @@ kernel void kernel_argsort_f32_i32(
     int col = tpitg[0];
     int row = tgpig[1];
 
-    if (col >= ncols_pad) return;
+    if (col >= args.ncols_pad) return;
 
-    device const float   * x_row   = x + row * ncols;
+    device const float   * x_row   = x + row * args.ncols;
     threadgroup int32_t  * dst_row = shared_values;
 
     // initialize indices
@@ -3081,21 +3090,21 @@ kernel void kernel_argsort_f32_i32(
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (int k = 2; k <= ncols_pad; k *= 2) {
+    for (int k = 2; k <= args.ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (dst_row[col] >= ncols ||
-                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                    if (dst_row[col] >= args.ncols ||
+                        (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
                             x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                             x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                     ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (dst_row[ixj] >= ncols ||
-                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                    if (dst_row[ixj] >= args.ncols ||
+                        (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
                             x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                             x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                     ) {
@@ -3108,8 +3117,8 @@ kernel void kernel_argsort_f32_i32(
     }
 
     // copy the result to dst without the padding
-    if (col < ncols) {
-        dst[row * ncols + col] = dst_row[col];
+    if (col < args.ncols) {
+        dst[row * args.ncols + col] = dst_row[col];
     }
 }
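+
+// Bitonic argsort over the padded column range: each of the ncols_pad threads in a row
+// runs the k/j compare-and-swap ladder on indices held in threadgroup memory, padded
+// indices (>= ncols) always sort after every valid element, and only the first ncols
+// indices are copied back to dst.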
 
@@ -3119,9 +3128,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar
 kernel void kernel_leaky_relu_f32(
         device const float * src0,
         device       float * dst,
-        constant     float & slope,
+        constant     ggml_metal_kargs_leaky_relu & args,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope;
 }
 
 // ref: https://arxiv.org/pdf/2307.08691.pdf
@@ -3148,7 +3157,8 @@ template<
     typename vd4x4_t, // key type in device memory
     short nl_v,
     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
-    short D,         // head size
+    short DK,        // K head size
+    short DV,        // V head size
     short Q  = 8,    // queries per threadgroup
     short KV = 8,    // key/value processed per each simdgroup
     short C  = 32>   // cache items per threadgroup
@@ -3170,20 +3180,24 @@ kernel void kernel_flash_attn_ext(
     const int iq2 = tgpig[1];
     const int iq1 = tgpig[0]*Q;
 
-    const short D4  = D/4;
-    const short D8  = D/8;
-    const short D16 = D/16;
-    const short NW  = N_SIMDWIDTH;
-    const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
+    constexpr short DK4  = DK/4;
+    constexpr short DK8  = DK/8;
+    constexpr short DK16 = DK/16;
+    constexpr short DV4  = DV/4;
+    constexpr short DV8  = DV/8;
+    constexpr short DV16 = DV/16;
+
+    constexpr short NW  = N_SIMDWIDTH;
+    constexpr short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
     const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
-    const short T  = D + 2*TS; // shared memory size per query in (half)
+    const short T  = DK + 2*TS; // shared memory size per query in (half)
 
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*D); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*D); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*D); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*D); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*DK); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*DK); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*DK); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3192,23 +3206,23 @@ kernel void kernel_flash_attn_ext(
     threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    o8x8_t lo[D8];
+    o8x8_t lo[DV8];
 
     // load heads from Q to shared memory
     for (short j = sgitg; j < Q; j += nsg) {
         device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
 
-        for (short i = tiisg; i < D4; i += NW) {
+        for (short i = tiisg; i < DK4; i += NW) {
             if (iq1 + j < args.ne01) {
-                sq4[j*D4 + i] = (q4_t) q4[i];
+                sq4[j*DK4 + i] = (q4_t) q4[i];
             } else {
-                sq4[j*D4 + i] = (q4_t) 0.0f;
+                sq4[j*DK4 + i] = (q4_t) 0.0f;
             }
         }
     }
 
     // zero out lo
-    for (short i = 0; i < D8; ++i) {
+    for (short i = 0; i < DV8; ++i) {
         lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f);
     }
 
@@ -3222,8 +3236,8 @@ kernel void kernel_flash_attn_ext(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        half S[Q] = { [0 ... Q-1] = 0.0f };
-        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float S[Q] = { [0 ... Q-1] = 0.0f };
+        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -3238,22 +3252,15 @@ kernel void kernel_flash_attn_ext(
         const short ikv2 = iq2/(args.ne02/args.ne_12_2);
         const short ikv3 = iq3/(args.ne03/args.ne_12_3);
 
-        // load the queries from shared memory into local memory
-        q8x8_t mq[D8];
-
-        for (short i = 0; i < D8; ++i) {
-            simdgroup_load(mq[i], sq + i*8, D);
-        }
-
         const bool has_mask = mask != q;
 
-        half slope = 1.0f;
+        float slope = 1.0f;
 
         // ALiBi
         if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
             const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
@@ -3269,14 +3276,14 @@ kernel void kernel_flash_attn_ext(
 
             if (has_mask) {
                 // used to detect blocks full of -INF
-                half smax = -INFINITY;
+                float smax = -INFINITY;
 
                 // load the mask in shared memory
                 #pragma unroll(Q)
                 for (short j = 0; j < Q; ++j) {
                     device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
 
-                    const half m = pm[ic + tiisg];
+                    const float m = pm[ic + tiisg];
 
                     ss[j*TS + C + tiisg] = m;
                     smax = max(smax, m);
@@ -3297,20 +3304,22 @@ kernel void kernel_flash_attn_ext(
                     // this is compile-time check, so it does not have runtime overhead
                     if (is_same::value) {
                         // we can read directly from global memory
-                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
 
-                        #pragma unroll(D8)
-                        for (short i = 0; i < D8; ++i) {
+                        #pragma unroll(DK8)
+                        for (short i = 0; i < DK8; ++i) {
                             k8x8_t mk;
-                            simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+                            simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
 
-                            simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                            q8x8_t mq;
+                            simdgroup_load(mq, sq + i*8, DK);
+                            simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                         }
                     } else {
-                        for (short ii = 0; ii < D16; ii += 4) {
-                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        for (short ii = 0; ii < DK16; ii += 4) {
+                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
 
-                            if (D16%4 == 0) {
+                            if (DK16%4 == 0) {
                                 // the head is evenly divisible by 4*16 = 64, so no need for bound checks
                                 {
                                     k4x4_t tmp;
@@ -3323,15 +3332,18 @@ kernel void kernel_flash_attn_ext(
                                 #pragma unroll(4)
                                 for (short k = 0; k < 4; ++k) {
                                     k8x8_t mk;
+                                    q8x8_t mq;
 
                                     simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
 
                                     simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                                 }
                             } else {
-                                if (ii + tx < D16) {
+                                if (ii + tx < DK16) {
                                     k4x4_t tmp;
                                     deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
                                     sk4x4[4*ty + tx] = tmp;
@@ -3339,14 +3351,17 @@ kernel void kernel_flash_attn_ext(
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
-                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                for (short k = 0; k < 4 && ii + k < DK16; ++k) {
                                     k8x8_t mk;
+                                    q8x8_t mq;
 
                                     simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
 
                                     simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                                 }
                             }
                         }
@@ -3364,10 +3379,10 @@ kernel void kernel_flash_attn_ext(
             // online softmax
             {
                 for (ushort j = 0; j < Q; ++j) {
-                    const half m = M[j];
+                    const float m = M[j];
 
                     // scale and apply the logitcap / mask
-                    half s = ss[j*TS + tiisg]*args.scale;
+                    float s = ss[j*TS + tiisg]*args.scale;
 
                     if (args.logit_softcap != 0.0f) {
                         s = args.logit_softcap*precise::tanh(s);
@@ -3378,8 +3393,8 @@ kernel void kernel_flash_attn_ext(
 
                     M[j] = simd_max(max(M[j], s));
 
-                    const half ms = exp(m - M[j]);
-                    const half vs = exp(s - M[j]);
+                    const float ms = exp(m - M[j]);
+                    const float vs = exp(s - M[j]);
 
                     S[j] = S[j]*ms + simd_sum(vs);
 
@@ -3398,8 +3413,8 @@ kernel void kernel_flash_attn_ext(
                 s8x8_t mm;
                 simdgroup_load(mm, ss + 2*C, TS, 0, false);
 
-                #pragma unroll(D8)
-                for (short i = 0; i < D8; ++i) {
+                #pragma unroll(DV8)
+                for (short i = 0; i < DV8; ++i) {
                     simdgroup_multiply(lo[i], mm, lo[i]);
                 }
             }
@@ -3412,20 +3427,20 @@ kernel void kernel_flash_attn_ext(
 
                     if (is_same::value) {
                         // we can read directly from global memory
-                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
 
-                        #pragma unroll(D8)
-                        for (short i = 0; i < D8; ++i) {
+                        #pragma unroll(DV8)
+                        for (short i = 0; i < DV8; ++i) {
                             v8x8_t mv;
-                            simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
+                            simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
 
                             simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
                         }
                     } else {
-                        for (short ii = 0; ii < D16; ii += 4) {
-                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        for (short ii = 0; ii < DV16; ii += 4) {
+                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
 
-                            if (D16%4 == 0) {
+                            if (DV16%4 == 0) {
                                 // no need for bound checks
                                 {
                                     v4x4_t tmp;
@@ -3446,7 +3461,7 @@ kernel void kernel_flash_attn_ext(
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
                                 }
                             } else {
-                                if (ii + tx < D16) {
+                                if (ii + tx < DV16) {
                                     v4x4_t tmp;
                                     deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
                                     sv4x4[4*ty + tx] = tmp;
@@ -3454,7 +3469,7 @@ kernel void kernel_flash_attn_ext(
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
-                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                for (short k = 0; k < 4 && ii + k < DV16; ++k) {
                                     v8x8_t mv;
 
                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
@@ -3481,15 +3496,15 @@ kernel void kernel_flash_attn_ext(
 
     // reduce the warps sequentially
     for (ushort sg = 1; sg < nsg; ++sg) {
-        half S = { 0.0f };
-        half M = { -__FLT16_MAX__/2 };
+        float S = { 0.0f };
+        float M = { -__FLT16_MAX__/2 };
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // each simdgroup stores its output to shared memory, reusing sq
         if (sgitg == sg) {
-            for (short i = 0; i < D8; ++i) {
-                simdgroup_store(lo[i], so + i*8, D, 0, false);
+            for (short i = 0; i < DV8; ++i) {
+                simdgroup_store(lo[i], so + i*8, DV, 0, false);
             }
         }
 
@@ -3498,16 +3513,16 @@ kernel void kernel_flash_attn_ext(
         // the first simdgroup accumulates the results from the other simdgroups
         if (sgitg == 0) {
             for (short j = 0; j < Q; ++j) {
-                const half S0 = ss[j*TS +         0];
-                const half S1 = ss[j*TS + sg*SH + 0];
+                const float S0 = ss[j*TS +         0];
+                const float S1 = ss[j*TS + sg*SH + 0];
 
-                const half M0 = ss[j*TS +         1];
-                const half M1 = ss[j*TS + sg*SH + 1];
+                const float M0 = ss[j*TS +         1];
+                const float M1 = ss[j*TS + sg*SH + 1];
 
                 M = max(M0, M1);
 
-                const half ms0 = exp(M0 - M);
-                const half ms1 = exp(M1 - M);
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
 
@@ -3528,11 +3543,11 @@ kernel void kernel_flash_attn_ext(
                 simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
                 simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
 
-                #pragma unroll(D8)
-                for (short i = 0; i < D8; ++i) {
+                #pragma unroll(DV8)
+                for (short i = 0; i < DV8; ++i) {
                     o8x8_t t;
 
-                    simdgroup_load    (t, so + i*8, D, 0, false);
+                    simdgroup_load    (t, so + i*8, DV, 0, false);
                     simdgroup_multiply(t, ms1, t);
 
                     simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@@ -3543,8 +3558,8 @@ kernel void kernel_flash_attn_ext(
 
     // store result to shared memory (reuse sq)
     if (sgitg == 0) {
-        for (short i = 0; i < D8; ++i) {
-            simdgroup_store(lo[i], so + i*8, D, 0, false);
+        for (short i = 0; i < DV8; ++i) {
+            simdgroup_store(lo[i], so + i*8, DV, 0, false);
         }
     }
 
@@ -3555,8 +3570,8 @@ kernel void kernel_flash_attn_ext(
         for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
             const float S = ss[j*TS + 0];
 
-            for (short i = tiisg; i < D4; i += NW) {
-                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
+            for (short i = tiisg; i < DV4; i += NW) {
+                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
             }
         }
     }
@@ -3573,80 +3588,94 @@ kernel void kernel_flash_attn_ext(
     float,          simdgroup_float8x8, \
     half,  half4,   simdgroup_half8x8
 
-typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t;
+typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t;
 
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h192")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 #endif
 
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #undef FA_TYPES
 
 template<
-    typename q4_t,    // query types in shared memory
-    typename q4x4_t,
-    typename k4x4_t,  // key types in shared memory
-    typename v4x4_t,  // value types in shared memory
-    typename qk_t,    // Q*K types
-    typename s_t,     // soft-max types
+    typename q4_t,  // query types in shared memory
+    typename k4_t,  // key types in shared memory
+    typename v4_t,  // value types in shared memory
+    typename qk_t,  // Q*K types
+    typename s_t,   // soft-max types
     typename s4_t,
-    typename s4x4_t,
-    typename o4x4_t,  // attention accumulation types
-    typename kd4x4_t, // key type in device memory
+    typename o4_t,  // attention accumulation types
+    typename kd4_t, // key type in device memory
     short nl_k,
-    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
-    typename vd4x4_t, // key type in device memory
+    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
+    typename vd4_t, // value type in device memory
     short nl_v,
-    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
-    short D,         // head size
-    short Q  = 1,    // queries per threadgroup
-    short C  = 32>   // cache items per threadgroup
+    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
+    short DK,       // K head size
+    short DV,       // V head size
+    short NE = 4,   // head elements per thread
+    short Q  = 1,   // queries per threadgroup
+    short C  = 32>  // cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec(
         constant ggml_metal_kargs_flash_attn_ext & args,
         device const char * q,
@@ -3665,29 +3694,28 @@ kernel void kernel_flash_attn_ext_vec(
     const int iq2 = tgpig[1];
     const int iq1 = tgpig[0];
 
-    const short D4  = D/4;
-    const short D16 = D/16;
-    const short NW  = N_SIMDWIDTH;
-    const short NL  = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
-    const short SH  = 2*C;  // shared memory per simdgroup
+    constexpr short DK4 = DK/4;
+    constexpr short DV4 = DV/4;
+    constexpr short NW  = N_SIMDWIDTH;
+    constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
+    constexpr short SH  = 4*C;   // shared memory per simdgroup
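+    // note: SH counts half elements: C float attention scores (ss) plus C float mask values (sm)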
 
-    const short T = D + nsg*SH; // shared memory size per query in (half)
+    const short T = DK + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shmem_f16 +                0*D); // holds the query data
-    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shmem_f16 +                0*D); // same as above but in q4_t
-    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 +                0*D); // same as above but in q4x4_t
-    threadgroup s_t    * ss    = (threadgroup s_t    *) (shmem_f16 + sgitg*SH     + Q*D); // scratch buffer for attention
-    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shmem_f16 + sgitg*SH     + Q*D); // same as above but in s4_t
-    threadgroup half   * sm    = (threadgroup half   *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
-    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D      + Q*T); // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                  0*DK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                  0*DK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH       + Q*DK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH       + Q*DK); // same as above but in s4_t
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV       + Q*T);  // scratch buffer for the results
 
-    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    o4x4_t lo[D16/NL];
+    // store the result for all queries in local memory (the O matrix from the paper)
+    o4_t lo[DV4/NL];
 
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
 
-    for (short i = tiisg; i < D4; i += NW) {
+    for (short i = tiisg; i < DK4; i += NW) {
         if (iq1 < args.ne01) {
             sq4[i] = (q4_t) q4[i];
         } else {
@@ -3696,8 +3724,8 @@ kernel void kernel_flash_attn_ext_vec(
     }
 
     // zero out lo
-    for (short i = 0; i < D16/NL; ++i) {
-        lo[i] = (o4x4_t) 0.0f;
+    for (short i = 0; i < DV4/NL; ++i) {
+        lo[i] = (o4_t) 0.0f;
     }
 
     // zero out shared memory SH
@@ -3708,8 +3736,8 @@ kernel void kernel_flash_attn_ext_vec(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        half S = 0.0f;
-        half M = -__FLT16_MAX__/2;
+        float S = 0.0f;
+        float M = -__FLT16_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
@@ -3722,26 +3750,18 @@ kernel void kernel_flash_attn_ext_vec(
         const short ikv2 = iq2/(args.ne02/args.ne_12_2);
         const short ikv3 = iq3/(args.ne03/args.ne_12_3);
 
-        // load the queries from shared memory into local memory
-        q4x4_t mq[D16/NL];
-
-        #pragma unroll(D16/NL)
-        for (short ii = 0; ii < D16; ii += NL) {
-            mq[ii/NL] = sq4x4[ii + tx];
-        }
-
         const bool has_mask = mask != q;
 
         // pointer to the mask
         device const half * pm = (device const half *) (mask + iq1*args.nb31);
 
-        half slope = 1.0f;
+        float slope = 1.0f;
 
         // ALiBi
         if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
             const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
@@ -3761,43 +3781,56 @@ kernel void kernel_flash_attn_ext_vec(
 
             // Q*K^T
             {
-                // each simdgroup processes 1 query and 4 (NW/NL) keys
-                for (short cc = 0; cc < C/4; ++cc) {
-                    qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
+                // each simdgroup processes 1 query and NE (NW/NL) head elements
+                for (short cc = 0; cc < C/NE; ++cc) {
+                    qk_t mqk = 0.0f;
 
-                    device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                    device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
 
-                    #pragma unroll(D16/NL)
-                    for (short ii = 0; ii < D16; ii += NL) {
+                    #pragma unroll(DK4/NL)
+                    for (short ii = 0; ii < DK4; ii += NL) {
                         const short i = ii + tx;
 
-                        k4x4_t mk;
-                        deq_k(pk + i/nl_k, i%nl_k, mk);
+                        k4_t mk;
+                        deq_k_t4(pk + i/nl_k, i%nl_k, mk);
 
                         // note: this is less precise than the version below
-                        //mqka[0] += dot(mq[ii/NL][0], mk[0]);
-                        //mqka[1] += dot(mq[ii/NL][1], mk[1]);
-                        //mqka[2] += dot(mq[ii/NL][2], mk[2]);
-                        //mqka[3] += dot(mq[ii/NL][3], mk[3]);
+                        //mqka[0] += dot(mq[0], mk[0]);
+                        //mqka[1] += dot(mq[1], mk[1]);
+                        //mqka[2] += dot(mq[2], mk[2]);
+                        //mqka[3] += dot(mq[3], mk[3]);
 
-                        mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
-                        mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
-                        mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
-                        mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
+                        //q4x4_t mq = sq4x4[i];
+                        //mqka[0] += dot((float4) mq[0], (float4) mk[0]);
+                        //mqka[1] += dot((float4) mq[1], (float4) mk[1]);
+                        //mqka[2] += dot((float4) mq[2], (float4) mk[2]);
+                        //mqka[3] += dot((float4) mq[3], (float4) mk[3]);
+
+                        mqk += dot((float4) mk, (float4) sq4[i]);
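+                        // note: the query is read directly from shared memory (sq4) instead of a
+                        //       per-thread local copy, with the dot product accumulated in fp32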
                     }
 
-                    qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
+                    static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
 
-                    // simdgroup reduce
+                    // simdgroup reduce (NE = 4)
                     // [ 0 ..  7] -> [ 0]
                     // [ 8 .. 15] -> [ 8]
                     // [16 .. 23] -> [16]
                     // [24 .. 31] -> [24]
-                  //mqk += simd_shuffle_down(mqk, 16);
-                  //mqk += simd_shuffle_down(mqk,  8);
-                    mqk += simd_shuffle_down(mqk,  4);
-                    mqk += simd_shuffle_down(mqk,  2);
-                    mqk += simd_shuffle_down(mqk,  1);
+                    if (NE <= 1) {
+                        mqk += simd_shuffle_down(mqk, 16);
+                    }
+                    if (NE <= 2) {
+                        mqk += simd_shuffle_down(mqk,  8);
+                    }
+                    if (NE <= 4) {
+                        mqk += simd_shuffle_down(mqk,  4);
+                    }
+                    if (NE <= 8) {
+                        mqk += simd_shuffle_down(mqk,  2);
+                    }
+                    if (NE <= 16) {
+                        mqk += simd_shuffle_down(mqk,  1);
+                    }
 
                     // mqk = mqk*scale + mask*slope
                     if (tx == 0) {
@@ -3807,9 +3840,9 @@ kernel void kernel_flash_attn_ext_vec(
                             mqk = args.logit_softcap*precise::tanh(mqk);
                         }
 
-                        mqk += sm[4*cc + ty]*slope;
+                        mqk += sm[NE*cc + ty]*slope;
 
-                        ss[4*cc + ty] = mqk;
+                        ss[NE*cc + ty] = mqk;
                     }
                 }
             }
@@ -3818,13 +3851,13 @@ kernel void kernel_flash_attn_ext_vec(
 
             // online softmax
             {
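+                // note: M tracks the running maximum of the scores and S the running sum of
+                //       exp(score - M); the previous sum and the partial O accumulator are
+                //       rescaled by ms = exp(m - M) so that all terms share the current maximum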
-                const half m = M;
-                const half s = ss[tiisg];
+                const float m = M;
+                const float s = ss[tiisg];
 
                 M = simd_max(max(M, s));
 
-                const half ms = exp(m - M);
-                const half vs = exp(s - M);
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
 
                 S = S*ms + simd_sum(vs);
 
@@ -3832,8 +3865,8 @@ kernel void kernel_flash_attn_ext_vec(
                 ss[tiisg] = vs;
 
                 // O = diag(ms)*O
-                #pragma unroll(D16/NL)
-                for (short ii = 0; ii < D16; ii += NL) {
+                #pragma unroll(DV4/NL)
+                for (short ii = 0; ii < DV4; ii += NL) {
                     lo[ii/NL] *= ms;
                 }
             }
@@ -3842,19 +3875,20 @@ kernel void kernel_flash_attn_ext_vec(
 
             // O = O + (Q*K^T)*V
             {
-                for (short cc = 0; cc < C/4; ++cc) {
-                    device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                //#pragma unroll(C/NE)
+                for (short cc = 0; cc < C/NE; ++cc) {
+                    device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
 
-                    const s4x4_t ms(ss[4*cc + ty]);
+                    const s4_t ms(ss[NE*cc + ty]);
 
-                    #pragma unroll(D16/NL)
-                    for (short ii = 0; ii < D16; ii += NL) {
+                    #pragma unroll(DV4/NL)
+                    for (short ii = 0; ii < DV4; ii += NL) {
                         const short i = ii + tx;
 
-                        v4x4_t mv;
-                        deq_v(pv4 + i/nl_v, i%nl_v, mv);
+                        v4_t mv;
+                        deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
 
-                        lo[ii/NL] += mv*ms;
+                        lo[ii/NL] += o4_t(float4(mv)*float4(ms));
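+                        // note: the dequantized V values are scaled by the softmax weight in fp32
+                        //       and converted to o4_t before accumulation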
                     }
                 }
             }
@@ -3867,7 +3901,7 @@ kernel void kernel_flash_attn_ext_vec(
         }
     }
 
-    // simdgroup reduce
+    // simdgroup reduce (NE = 4)
     // [ 0,  8, 16, 24] -> [ 0]
     // [ 1,  9, 17, 25] -> [ 1]
     // [ 2, 10, 18, 26] -> [ 2]
@@ -3876,37 +3910,48 @@ kernel void kernel_flash_attn_ext_vec(
     // [ 5, 13, 21, 29] -> [ 5]
     // [ 6, 14, 22, 30] -> [ 6]
     // [ 7, 15, 23, 31] -> [ 7]
-    for (short ii = 0; ii < D16; ii += NL) {
-        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
-        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
+    for (short ii = 0; ii < DV4; ii += NL) {
+        if (NE > 1) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
+        }
 
-        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
-        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
+        if (NE > 2) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
+        }
 
-        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
-        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
+        if (NE > 4) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
+        }
 
-        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
-        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
+        if (NE > 8) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
+        }
+
+        if (NE > 16) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
+        }
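+        // note: NE is a template constant, so only the shuffle offsets required for the chosen
+        //       group size are needed; for the default NE = 4 these are the offsets 16 and 8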
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // store results to shared memory
-    for (short i = tiisg; i < D16; i += NL) {
-        sr4x4[i] = lo[i/NL];
+    for (short i = tiisg; i < DV4; i += NL) {
+        sr4[i] = lo[i/NL];
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3914,18 +3959,18 @@ kernel void kernel_flash_attn_ext_vec(
     // parallel reduce
     for (short r = nsg/2; r > 0; r >>= 1) {
         if (sgitg < r) {
-            const half S0 = ss[       0];
-            const half S1 = ss[r*SH + 0];
+            const float S0 = ss[           0];
+            const float S1 = ss[r*(SH/2) + 0];
 
-            const half M0 = ss[       1];
-            const half M1 = ss[r*SH + 1];
+            const float M0 = ss[           1];
+            const float M1 = ss[r*(SH/2) + 1];
 
-            const half M = max(M0, M1);
+            const float M = max(M0, M1);
 
-            const half ms0 = exp(M0 - M);
-            const half ms1 = exp(M1 - M);
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
 
-            const half S = S0*ms0 + S1*ms1;
+            const float S = S0*ms0 + S1*ms1;
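+            // note: the partial (S, M) states of simdgroups sgitg and sgitg + r are merged with
+            //       the usual max/rescale rule; the stride is r*(SH/2) because ss is a float
+            //       pointer while SH counts half elements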
 
             if (tiisg == 0) {
                 ss[0] = S;
@@ -3933,22 +3978,22 @@ kernel void kernel_flash_attn_ext_vec(
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (short i = tiisg; i < D16; i += NW) {
-                sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
+            for (short i = tiisg; i < DV4; i += NW) {
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
             }
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    device float4x4 * dst44 = (device float4x4 *) dst;
+    device float4 * dst4 = (device float4 *) dst;
 
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
         const float S = ss[0];
 
-        for (short i = tiisg; i < D16; i += NW) {
-            dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
+        for (short i = tiisg; i < DV4; i += NW) {
+            dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
         }
     }
 }
@@ -3957,34 +4002,54 @@ kernel void kernel_flash_attn_ext_vec(
 //       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
 #define FA_TYPES \
-           half4,  half4x4, \
-                   half4x4, \
-                   half4x4, \
-    float,                  \
-    half,  half4,  half4x4, \
-                   half4x4
+           half4,  \
+           half4,  \
+           half4,  \
+    float,         \
+    float, float4, \
+           half4
 
-typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t;
+typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 
 #undef FA_TYPES
 
@@ -4359,7 +4424,7 @@ kernel void kernel_cpy_f32_iq4_nl(
         float amax = 0.0f; // absolute max
         float max  = 0.0f;
 
-        for (int j = 0; j < QK4_0; j++) {
+        for (int j = 0; j < QK4_NL; j++) {
             const float v = src[j];
             if (amax < fabs(v)) {
                 amax = fabs(v);
@@ -4467,7 +4532,7 @@ kernel void kernel_concat(
     }
 }
 
-template
+template
 void kernel_mul_mv_q2_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -4483,7 +4548,7 @@ void kernel_mul_mv_q2_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
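+    // note: nr0 rows per simdgroup and nsg simdgroups per threadgroup are template parameters
+    //       here instead of the N_DST/N_SIMDGROUP macros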
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -4495,20 +4560,19 @@ void kernel_mul_mv_q2_K_f32_impl(
     device const float      * y = (device const float      *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
-    const int is = (8*ir)/16;// 0 or 1
+    const short ix = tiisg/8;  // 0...3
+    const short it = tiisg%8;  // 0...7
+    const short iq = it/4;     // 0 or 1
+    const short ir = it%4;     // 0...3
+    const short is = (8*ir)/16;// 0 or 1
 
     device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
 
     for (int ib = ix; ib < nb; ib += 4) {
-
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
+        for (short i = 0; i < 8; ++i) {
             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
             yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
             yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
@@ -4519,7 +4583,7 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half     * dh = &x[ib].d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {
@@ -4550,10 +4614,10 @@ void kernel_mul_mv_q2_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
@@ -4568,10 +4632,10 @@ kernel void kernel_mul_mv_q2_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template
+template
 void kernel_mul_mv_q3_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -4588,7 +4652,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -4604,13 +4668,12 @@ void kernel_mul_mv_q3_K_f32_impl(
     //const uint16_t kmask1 = 0x3030;
     //const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int ip  = tid/4;          // 0 or 1
-    const int il  = 2*((tid%4)/2);  // 0 or 2
-    const int ir  = tid%2;
-    const int n   = 8;
-    const int l0  = n*ir;
+    const short tid = tiisg/4;
+    const short ix  = tiisg%4;
+    const short ip  = tid/4;          // 0 or 1
+    const short il  = 2*((tid%4)/2);  // 0 or 2
+    const short ir  = tid%2;
+    const short l0  = 8*ir;
 
     // One would think that the Metal compiler would figure out that ip and il can only have
     // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
@@ -4635,8 +4698,8 @@ void kernel_mul_mv_q3_K_f32_impl(
     const uint16_t s_shift1 = 4*ip;
     const uint16_t s_shift2 = s_shift1 + il;
 
-    const int q_offset = 32*ip + l0;
-    const int y_offset = 128*ip + 32*il + l0;
+    const short q_offset = 32*ip + l0;
+    const short y_offset = 128*ip + 32*il + l0;
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
@@ -4644,10 +4707,11 @@ void kernel_mul_mv_q3_K_f32_impl(
     thread uint16_t * scales16 = (thread uint16_t *)&scales32;
     thread const int8_t * scales = (thread const int8_t *)&scales32;
 
-    float sumf1[2] = {0.f};
-    float sumf2[2] = {0.f};
+    float sumf1[nr0] = {0.f};
+    float sumf2[nr0] = {0.f};
+
     for (int i = ix; i < nb; i += 4) {
-        for (int l = 0; l < 8; ++l) {
+        for (short l = 0; l < 8; ++l) {
             yl[l+ 0] = y1[l+ 0];
             yl[l+ 8] = y1[l+16];
             yl[l+16] = y1[l+32];
@@ -4659,7 +4723,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         device const uint16_t * a = (device const uint16_t *)(x[i].scales);
         device const half * dh = &x[i].d;
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             const float d_all = (float)dh[0];
 
             scales16[0] = a[4];
@@ -4670,7 +4734,7 @@ void kernel_mul_mv_q3_K_f32_impl(
             scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
 
             float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
-            for (int l = 0; l < n; l += 2) {
+            for (short l = 0; l < 8; l += 2) {
                 const int32_t qs = q[l/2];
                 s1 += yl[l+0] * (qs & qm[il/2][0]);
                 s2 += yl[l+1] * (qs & qm[il/2][1]);
@@ -4685,7 +4749,7 @@ void kernel_mul_mv_q3_K_f32_impl(
             sumf2[row] += d2 * (scales[2] - 32);
 
             s1 = s2 = s3 = s4 = s5 = s6 = 0;
-            for (int l = 0; l < n; l += 2) {
+            for (short l = 0; l < 8; l += 2) {
                 const int32_t qs = q[l/2+8];
                 s1 += yl[l+8] * (qs & qm[il/2][0]);
                 s2 += yl[l+9] * (qs & qm[il/2][1]);
@@ -4708,7 +4772,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         y1 += 4 * QK_K;
     }
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
         sumf1[row] = simd_sum(sumf);
     }
@@ -4716,7 +4780,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
     if (tiisg == 0) {
-        for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+        for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
             dst_f32[first_row + row] = sumf1[row];
         }
     }
@@ -4732,10 +4796,10 @@ kernel void kernel_mul_mv_q3_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template
+template
 void kernel_mul_mv_q4_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -4745,22 +4809,22 @@ void kernel_mul_mv_q4_K_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
-
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
+    const short ix = tiisg/8;  // 0...3
+    const short it = tiisg%8;  // 0...7
+    const short iq = it/4;     // 0 or 1
+    const short ir = it%4;     // 0...3
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int first_row = r0 * N_DST;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -4773,7 +4837,8 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     float yl[16];
     float yh[16];
-    float sumf[N_DST]={0.f}, all_sum;
+
+    float sumf[nr0]={0.f};
 
     device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
 
@@ -4782,7 +4847,8 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     for (int ib = ix; ib < nb; ib += 4) {
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
+
+        for (short i = 0; i < 8; ++i) {
             yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
             yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
             yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
@@ -4793,7 +4859,7 @@ void kernel_mul_mv_q4_K_f32_impl(
         device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half     * dh = &x[ib].d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             sc16[0] = sc[0] & kmask1;
             sc16[1] = sc[2] & kmask1;
             sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
@@ -4803,19 +4869,21 @@ void kernel_mul_mv_q4_K_f32_impl(
 
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
-            for (int i = 0; i < 8; i += 2) {
-                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
-                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
-                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
-                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
-                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
-                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
-                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
-                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+
+            for (short i = 0; i < 4; ++i) {
+                acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
+                acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
+                acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
+                acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
+                acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
+                acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
+                acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
+                acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
             }
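+            // note: each uint16_t of q1/q2 packs four 4-bit weights; the masks select them in
+            //       place and the 1.f/256.f and 1.f/16.f factors below undo the bit offsets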
 
             float dall = dh[0];
             float dmin = dh[1];
+
             sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
                                  (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
                                  (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
@@ -4832,10 +4900,10 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
@@ -4850,10 +4918,10 @@ kernel void kernel_mul_mv_q4_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template
+template
 void kernel_mul_mv_q5_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -4870,7 +4938,7 @@ void kernel_mul_mv_q5_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -4881,7 +4949,7 @@ void kernel_mul_mv_q5_K_f32_impl(
     device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
     device const float     * yy = (device const float      *) (src1 + offset1);
 
-    float sumf[2]={0.f};
+    float sumf[nr0]={0.f};
 
     float yl[16], yh[16];
 
@@ -4889,15 +4957,14 @@ void kernel_mul_mv_q5_K_f32_impl(
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int iq  = tid/4;
-    const int ir  = tid%4;
-    const int n   = 8;
+    const short tid = tiisg/4;
+    const short ix  = tiisg%4;
+    const short iq  = tid/4;
+    const short ir  = tid%4;
 
-    const int l0 = n*ir;
-    const int q_offset = 32*iq + l0;
-    const int y_offset = 64*iq + l0;
+    const short l0 = 8*ir;
+    const short q_offset = 32*iq + l0;
+    const short y_offset = 64*iq + l0;
 
     const uint8_t hm1 = 1u << (2*iq);
     const uint8_t hm2 = hm1 << 1;
@@ -4917,14 +4984,14 @@ void kernel_mul_mv_q5_K_f32_impl(
 
         device const float * y2 = y1 + 128;
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < 8; ++l) {
+        for (short l = 0; l < 8; ++l) {
             yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
             yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
             yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
             yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
         }
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             device const uint8_t * q2 = q1 + 64;
 
             sc16[0] = a[0] & kmask1;
@@ -4934,7 +5001,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
             float4 acc1 = {0.f};
             float4 acc2 = {0.f};
-            for (int l = 0; l < n; ++l) {
+            for (short l = 0; l < 8; ++l) {
                 uint8_t h = qh[l];
                 acc1[0] += yl[l+0] * (q1[l] & 0x0F);
                 acc1[1] += yl[l+8] * (q1[l] & 0xF0);
@@ -4964,7 +5031,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = tot;
@@ -4982,10 +5049,10 @@ kernel void kernel_mul_mv_q5_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template 
+template
 void kernel_mul_mv_q6_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -5007,62 +5074,77 @@ void kernel_mul_mv_q6_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int row = 2*r0 + sgitg;
-
-    if (row >= args.ne0) {
-        return;
-    }
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
-    const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    const uint64_t offset1 =  r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
     device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
     device const float     * yy = (device const float      *) (src1 + offset1);
 
-    float sumf = 0;
+    float sumf[nr0] = { 0.f };
 
-    const int tid  = tiisg/2;
-    const int ix   = tiisg%2;
-    const int ip   = tid/8;         // 0 or 1
-    const int il   = tid%8;
-    const int n    = 4;
-    const int l0   = n*il;
-    const int is   = 8*ip + l0/16;
+    float yl[16];
 
-    const int y_offset = 128*ip + l0;
-    const int q_offset_l = 64*ip + l0;
-    const int q_offset_h = 32*ip + l0;
+    const short tid = tiisg/2;
+    const short ix  = tiisg%2;
+    const short ip  = tid/8;         // 0 or 1
+    const short il  = tid%8;
+    const short l0  = 4*il;
+    const short is  = 8*ip + l0/16;
+
+    const short y_offset   = 128*ip + l0;
+    const short q_offset_l =  64*ip + l0;
+    const short q_offset_h =  32*ip + l0;
 
     for (int i = ix; i < nb; i += 2) {
         device const uint8_t * q1 = x[i].ql + q_offset_l;
         device const uint8_t * q2 = q1 + 32;
         device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
+        device const half    * dh = &x[i].d;
 
         device const float * y = yy + i * QK_K + y_offset;
 
-        const float dall = x[i].d;
-
-        float4 sums = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        for (short l = 0; l < 4; ++l) {
+            yl[4*l + 0] = y[l +  0];
+            yl[4*l + 1] = y[l + 32];
+            yl[4*l + 2] = y[l + 64];
+            yl[4*l + 3] = y[l + 96];
         }
 
-        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+        for (short row = 0; row < nr0; ++row) {
+            const float dall = dh[0];
 
+            float4 sums = {0.f, 0.f, 0.f, 0.f};
+
+            for (short l = 0; l < 4; ++l) {
+                sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+                sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+                sums[2] += yl[4*l + 2] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+                sums[3] += yl[4*l + 3] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+            }
+
+            sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+            q1 += args.nb01;
+            q2 += args.nb01;
+            qh += args.nb01;
+            sc += args.nb01;
+            dh += args.nb01/2;
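+            // note: step the block pointers to the next row; nb01 is the row stride in bytes,
+            //       and dh is a half pointer, hence nb01/2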
+        }
     }
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    const float tot = simd_sum(sumf);
-    if (tiisg == 0) {
-        dst_f32[row] = tot;
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = sum_all;
+        }
     }
 }
 
@@ -5076,12 +5158,12 @@ kernel void kernel_mul_mv_q6_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
 
-template
+template
 void kernel_mul_mv_iq2_xxs_f32_impl(
         args_t args,
         device const char * src0,
@@ -5097,7 +5179,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5109,7 +5191,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     device const float         * y = (device const float         *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -5130,8 +5212,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -5142,18 +5223,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
         device const uint16_t * q2 = xr->qs + 4 * ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             device const uint8_t * aux8 = (device const uint8_t *)q2;
             const uint32_t aux32 = q2[2] | (q2[3] << 16);
             const float d = db * (0.5f + (aux32 >> 28));
 
             float sum = 0;
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
@@ -5168,10 +5248,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
         }
     }
 }
@@ -5186,10 +5266,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq2_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -5205,7 +5285,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5217,7 +5297,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     device const float        * y = (device const float        *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -5238,8 +5318,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -5251,8 +5330,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
         device const uint8_t  * sc = xr->scales + ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const uint8_t ls1 = sc[0] & 0xf;
             const uint8_t ls2 = sc[0] >>  4;
@@ -5260,17 +5338,17 @@ void kernel_mul_mv_iq2_xs_f32_impl(
             const float d2 = db * (0.5f + ls2);
 
             float sum1 = 0, sum2 = 0;
-            for (int l = 0; l < 2; ++l) {
+            for (short l = 0; l < 2; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                 const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
-            for (int l = 2; l < 4; ++l) {
+            for (short l = 2; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                 const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
@@ -5286,10 +5364,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
         }
     }
 }
@@ -5305,10 +5383,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template <typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq3_xxs_f32_impl(
         args_t args,
         device const char * src0,
@@ -5324,7 +5402,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5336,7 +5414,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     device const float         * y = (device const float         *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -5357,7 +5435,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -5369,17 +5447,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
         device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const uint32_t aux32 = gas[0] | (gas[1] << 16);
             const float d = db * (0.5f + (aux32 >> 28));
 
             float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 4; ++j) {
+                for (short j = 0; j < 4; ++j) {
                     sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                     sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                 }
@@ -5396,10 +5474,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.5f;
+            dst_f32[first_row + row] = sum_all * 0.5f;
         }
     }
 }
@@ -5415,10 +5493,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq3_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -5434,7 +5512,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5446,7 +5524,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -5463,8 +5541,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -5478,18 +5555,17 @@ void kernel_mul_mv_iq3_s_f32_impl(
         device const uint8_t * signs = xr->signs + 4 * ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
 
             float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
                 const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
-                for (int j = 0; j < 4; ++j) {
+                for (short j = 0; j < 4; ++j) {
                     sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
                     sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                 }
@@ -5508,10 +5584,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
@@ -5527,10 +5603,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template <typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq2_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -5546,7 +5622,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5558,7 +5634,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -5570,13 +5646,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
     //    threadgroup_barrier(mem_flags::mem_threadgroup);
     //}
 
-    const int ix = tiisg;
+    const short ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -5590,19 +5665,18 @@ void kernel_mul_mv_iq2_s_f32_impl(
         device const uint8_t * signs = qs + QK_K/8;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const float d1 = db * (0.5f + (sc[0] & 0xf));
             const float d2 = db * (0.5f + (sc[0] >>  4));
 
             float2 sum = {0};
-            for (int l = 0; l < 2; ++l) {
+            for (short l = 0; l < 2; ++l) {
                 //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                 //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                 constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                 constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
                     sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
                 }
@@ -5621,10 +5695,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
         }
     }
 }
@@ -5640,10 +5714,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq1_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -5659,7 +5733,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5671,18 +5745,17 @@ void kernel_mul_mv_iq1_s_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
-    const int ix = tiisg;
+    const short ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
         float sumy = 0;
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
             sumy += yl[i];
         }
@@ -5695,15 +5768,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
         device const uint16_t * qh = xr->qh + ib;
         device const half     * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
             constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
             constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
 
             float sum = 0;
-            for (int j = 0; j < 4; ++j) {
+            for (short j = 0; j < 4; ++j) {
                 sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                      + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
                      + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5721,15 +5793,28 @@ void kernel_mul_mv_iq1_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-template <typename args_t>
+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq1_m_f32_impl(
         args_t args,
         device const char * src0,
@@ -5741,11 +5826,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
         ushort sgitg) {
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5757,20 +5843,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
-    const int ix = tiisg;
+    const short ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
 
     iq1m_scale_t scale;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
         float4 sumy = {0.f};
-        for (int i = 0; i < 8; ++i) {
+        for (short i = 0; i < 8; ++i) {
             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
             yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
             yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
@@ -5785,7 +5870,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
         device const uint8_t  * qh = xr->qh + 2 * ib;
         device const uint16_t * sc = (device const uint16_t *)xr->scales;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
 
             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -5794,7 +5879,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
 
             float2 sum = {0.f};
-            for (int j = 0; j < 4; ++j) {
+            for (short j = 0; j < 4; ++j) {
                 sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                         + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                 sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5816,15 +5901,28 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-template<typename args_t>
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq4_nl_f32_impl(
         args_t args,
         device const char * src0,
@@ -5837,10 +5935,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
     const int nb = args.ne00/QK4_NL;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * 2 + sgitg) * 2;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5851,14 +5951,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
-    const int ix = tiisg/2;  // 0...15
-    const int it = tiisg%2;  // 0 or 1
+    const short ix = tiisg/2;  // 0...15
+    const short it = tiisg%2;  // 0 or 1
 
     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[2]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     device const float * yb = y + ix * QK4_NL + it * 8;
 
@@ -5868,12 +5968,13 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     float4 qf1, qf2;
 
     for (int ib = ix; ib < nb; ib += 16) {
-
         device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
-
-        for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
 
+        for (short row = 0; row < nr0; row++) {
             device const block_iq4_nl & xb = x[row*nb + ib];
             device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
 
@@ -5898,7 +5999,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
             acc1 += acc2;
 
             sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
         }
 
         yb += 16 * QK4_NL;
@@ -5906,15 +6006,29 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-template<typename args_t>
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq4_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -5930,7 +6044,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * 2 + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5941,16 +6055,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
-    const int ix = tiisg/16;  // 0 or 1
-    const int it = tiisg%16;  // 0...15
-    const int ib = it/2;
-    const int il = it%2;
+    const short ix = tiisg/16;  // 0 or 1
+    const short it = tiisg%16;  // 0...15
+    const short ib = it/2;
+    const short il = it%2;
 
     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[2]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
 
@@ -5961,9 +6075,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     for (int ibl = ix; ibl < nb; ibl += 2) {
         device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             device const block_iq4_xs & xb = x[row*nb + ibl];
             device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
 
@@ -5987,7 +6104,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
             const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
             sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
         }
 
         yb += 2 * QK_K;
@@ -5995,54 +6111,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-[[host_name("kernel_mul_mv_iq1_s_f32")]]
-kernel void kernel_mul_mv_iq1_s_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq1_m_f32")]]
-kernel void kernel_mul_mv_iq1_m_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq4_nl_f32")]]
-kernel void kernel_mul_mv_iq4_nl_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        threadgroup  char * shmem [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
-}
-
 [[host_name("kernel_mul_mv_iq4_xs_f32")]]
 kernel void kernel_mul_mv_iq4_xs_f32(
         constant ggml_metal_kargs_mul_mv & args,
@@ -6054,7 +6130,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6062,28 +6138,21 @@ kernel void kernel_get_rows_q(
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
 
     const int64_t i02 = i11;
 
-    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+    for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
         float4x4 temp;
-        dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+        dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
     }
 }
 
@@ -6092,27 +6161,20 @@ kernel void kernel_get_rows_f(
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
 
     const int64_t i02 = i11;
 
-    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((      device float *) ((      device char *)  dst + i11*nb2  + i10*nb1))[ind] =
-        ((const device T     *) ((const device char *) src0 + i02*nb02 +  r*nb01))[ind];
+    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
+        ((      device float *) ((      device char *)  dst + i11*args.nb2  + i10*args.nb1))[ind] =
+        ((const device T     *) ((const device char *) src0 + i02*args.nb02 +  r*args.nb01))[ind];
     }
 }
 
@@ -6120,27 +6182,20 @@ kernel void kernel_get_rows_i32(
         device const  void * src0,
         device const  void * src1,
         device     int32_t * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
 
     const int64_t i02 = i11;
 
-    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((      device int32_t *) ((      device char *) dst  + i11*nb2 + i10*nb1))[ind] =
-        ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
+        ((      device int32_t *) ((      device char *) dst  + i11*args.nb2 + i10*args.nb1))[ind] =
+        ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
     }
 }
 
@@ -6719,121 +6774,103 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
 #endif
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
-template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
 
 kernel void kernel_pool_2d_max_f32(
         device  const float * src0,
         device        float * dst,
-        constant    int32_t & k0,
-        constant    int32_t & k1,
-        constant    int32_t & s0,
-        constant    int32_t & s1,
-        constant    int32_t & p0,
-        constant    int32_t & p1,
-        constant    int64_t & IH,
-        constant    int64_t & IW,
-        constant    int64_t & OH,
-        constant    int64_t & OW,
-        constant    int64_t & parallel_elements,
+        constant    ggml_metal_kargs_pool_2d & args,
         uint        gid[[thread_position_in_grid]]) {
 
-    if (gid >= parallel_elements) {
+    if (gid >= args.parallel_elements) {
         return;
     }
 
     const int idx = gid;
-    const int I_HW = IH * IW;
-    const int O_HW = OH * OW;
+    const int I_HW = args.IH * args.IW;
+    const int O_HW = args.OH * args.OW;
     const int nc = idx / O_HW;
-    const int cur_oh = idx % O_HW / OW;
-    const int cur_ow = idx % O_HW % OW;
+    const int cur_oh = idx % O_HW / args.OW;
+    const int cur_ow = idx % O_HW % args.OW;
 
     device const float * i_ptr = src0 + nc * I_HW;
     device       float * o_ptr = dst  + nc * O_HW;
 
-    const int start_h = cur_oh * s1 - p1;
+    const int start_h = cur_oh * args.s1 - args.p1;
     const int bh = MAX(0,  start_h);
-    const int eh = MIN(IH, start_h + k1);
-    const int start_w = cur_ow * s0 - p0;
+    const int eh = MIN(args.IH, start_h + args.k1);
+    const int start_w = cur_ow * args.s0 - args.p0;
     const int bw = MAX(0,  start_w);
-    const int ew = MIN(IW, start_w + k0);
+    const int ew = MIN(args.IW, start_w + args.k0);
 
     float res = -INFINITY;
 
     for (int i = bh; i < eh; i += 1) {
         for (int j = bw; j < ew; j += 1) {
-            res = MAX(res, i_ptr[i * IW + j]);
+            res = MAX(res, i_ptr[i * args.IW + j]);
         }
     }
 
-    o_ptr[cur_oh * OW + cur_ow] = res;
+    o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
 
 kernel void kernel_pool_2d_avg_f32(
         device  const float * src0,
         device        float * dst,
-        constant    int32_t & k0,
-        constant    int32_t & k1,
-        constant    int32_t & s0,
-        constant    int32_t & s1,
-        constant    int32_t & p0,
-        constant    int32_t & p1,
-        constant    int64_t & IH,
-        constant    int64_t & IW,
-        constant    int64_t & OH,
-        constant    int64_t & OW,
-        constant    int64_t & parallel_elements,
+        constant    ggml_metal_kargs_pool_2d & args,
         uint        gid[[thread_position_in_grid]]) {
 
-    if (gid >= parallel_elements) {
+    if (gid >= args.parallel_elements) {
         return;
     }
 
     const int idx = gid;
-    const int I_HW = IH * IW;
-    const int O_HW = OH * OW;
+    const int I_HW = args.IH * args.IW;
+    const int O_HW = args.OH * args.OW;
     const int nc = idx / O_HW;
-    const int cur_oh = idx % O_HW / OW;
-    const int cur_ow = idx % O_HW % OW;
+    const int cur_oh = idx % O_HW / args.OW;
+    const int cur_ow = idx % O_HW % args.OW;
 
     device const float * i_ptr = src0 + nc * I_HW;
     device       float * o_ptr = dst  + nc * O_HW;
 
-    const int start_h = cur_oh * s1 - p1;
+    const int start_h = cur_oh * args.s1 - args.p1;
     const int bh = MAX(0,  start_h);
-    const int eh = MIN(IH, start_h + k1);
-    const int start_w = cur_ow * s0 - p0;
+    const int eh = MIN(args.IH, start_h + args.k1);
+    const int start_w = cur_ow * args.s0 - args.p0;
     const int bw = MAX(0,  start_w);
-    const int ew = MIN(IW, start_w + k0);
+    const int ew = MIN(args.IW, start_w + args.k0);
     // const float scale = 1. / ((eh - bh) * (ew - bw));
-    const float scale = 1. / (k0 * k1);
+    const float scale = 1. / (args.k0 * args.k1);
 
     float res = 0;
 
     for (int i = bh; i < eh; i += 1) {
         for (int j = bw; j < ew; j += 1) {
-            float cur = i_ptr[i * IW + j];
+            float cur = i_ptr[i * args.IW + j];
             res += cur * scale;
         }
     }
 
-    o_ptr[cur_oh * OW + cur_ow] = res;
+    o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-quants.c b/ml/backend/ggml/ggml/src/ggml-quants.c
index 7918388ae..ac918a60d 100644
--- a/ml/backend/ggml/ggml/src/ggml-quants.c
+++ b/ml/backend/ggml/ggml/src/ggml-quants.c
@@ -28,7 +28,7 @@
 #define UNUSED GGML_UNUSED
 
 // reference implementation for deterministic creation of model files
-void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
+void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
 
     assert(k % qk == 0);
@@ -65,7 +65,7 @@ void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, in
     }
 }
 
-void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
+void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k) {
     const int qk = QK4_1;
 
     assert(k % qk == 0);
@@ -102,7 +102,7 @@ void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, in
     }
 }
 
-void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
+void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK5_0;
 
     assert(k % qk == 0);
@@ -146,7 +146,7 @@ void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, in
     }
 }
 
-void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
+void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k) {
     const int qk = QK5_1;
 
     assert(k % qk == 0);
@@ -191,7 +191,7 @@ void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, in
 }
 
 // reference implementation for deterministic creation of model files
-void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
+void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
 
@@ -217,7 +217,7 @@ void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, in
 }
 
 // reference implementation for deterministic creation of model files
-void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
+void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k) {
     assert(QK8_1 == 32);
     assert(k % QK8_1 == 0);
     const int nb = k / QK8_1;
@@ -252,7 +252,7 @@ void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, in
     }
 }
 
-void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
 
     assert(k % qk == 0);
@@ -272,7 +272,7 @@ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int6
     }
 }
 
-void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_1;
 
     assert(k % qk == 0);
@@ -293,7 +293,7 @@ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int6
     }
 }
 
-void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK5_0;
 
     assert(k % qk == 0);
@@ -319,7 +319,7 @@ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int6
     }
 }
 
-void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK5_1;
 
     assert(k % qk == 0);
@@ -346,7 +346,7 @@ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int6
     }
 }
 
-void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK8_0;
 
     assert(k % qk == 0);
@@ -376,8 +376,8 @@ static inline int nearest_int(float fval) {
     return (i & 0x007fffff) - 0x00400000;
 }
 
-static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
-        const float * restrict qw) {
+static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, int rmse_type,
+        const float * GGML_RESTRICT qw) {
     float max = 0;
     float amax = 0;
     for (int i = 0; i < n; ++i) {
@@ -445,7 +445,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
     return scale;
 }
 
-static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
+static float make_q3_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, bool do_rmse) {
     float max = 0;
     float amax = 0;
     for (int i = 0; i < n; ++i) {
@@ -504,7 +504,7 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
     return 1/iscale;
 }
 
-static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+static float make_qkx1_quants(int n, int nmax, const float * GGML_RESTRICT x, uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min,
         int ntry, float alpha) {
     float min = x[0];
     float max = x[0];
@@ -547,8 +547,8 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
     return scale;
 }
 
-static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
-        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights,
+        uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, uint8_t * GGML_RESTRICT Laux,
         float rmin, float rdelta, int nstep, bool use_mad) {
     float min = x[0];
     float max = x[0];
@@ -628,7 +628,7 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const f
     return scale;
 }
 
-static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+static inline void get_scale_min_k4(int j, const uint8_t * GGML_RESTRICT q, uint8_t * GGML_RESTRICT d, uint8_t * GGML_RESTRICT m) {
     if (j < 4) {
         *d = q[j] & 63; *m = q[j + 4] & 63;
     } else {
@@ -639,7 +639,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
 
 //========================- 2-bit (de)-quantization
 
-void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) {
+void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -709,7 +709,7 @@ void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, in
     }
 }
 
-void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -741,8 +741,8 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6
     }
 }
 
-static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
-        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights,
+        uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, uint8_t * GGML_RESTRICT Laux,
         float rmin, float rdelta, int nstep, bool use_mad) {
     float min = x[0];
     float max = x[0];
@@ -824,7 +824,7 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
     return scale;
 }
 
-static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) {
+static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint8_t * GGML_RESTRICT L, const float * quant_weights) {
     float max = 0;
     for (int i = 0; i < n; ++i) {
         max = MAX(max, x[i]);
@@ -897,7 +897,7 @@ static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t *
     return sumlx/suml2;
 }
 
-static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) {
+static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k, const float * GGML_RESTRICT quant_weights) {
     GGML_ASSERT(quant_weights);
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
@@ -917,7 +917,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
         for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
         float sigma2 = sumx2/QK_K;
         for (int j = 0; j < QK_K/16; ++j) {
-            const float * restrict qw = quant_weights + QK_K * i + 16*j;
+            const float * GGML_RESTRICT qw = quant_weights + QK_K * i + 16*j;
             for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
             for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
             scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
@@ -959,7 +959,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
     }
 }
 
-size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row);
@@ -977,7 +977,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nr
 
 //========================= 3-bit (de)-quantization
 
-void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) {
+void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -1053,7 +1053,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
     }
 }
 
-void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -1067,8 +1067,8 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6
 
         const float d_all = GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict q = x[i].qs;
-        const uint8_t * restrict hm = x[i].hmask;
+        const uint8_t * GGML_RESTRICT q = x[i].qs;
+        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
         uint8_t m = 1;
 
         memcpy(aux, x[i].scales, 12);
@@ -1103,7 +1103,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6
     }
 }
 
-static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
+static void quantize_row_q3_K_impl(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t n_per_row, const float * GGML_RESTRICT quant_weights) {
     assert(n_per_row % QK_K == 0);
     const int nb = n_per_row / QK_K;
 
@@ -1187,7 +1187,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
     }
 }
 
-size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row);
@@ -1205,7 +1205,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr
 
 // ====================== 4-bit (de)-quantization
 
-void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) {
+void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -1277,7 +1277,7 @@ void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, in
     }
 }
 
-void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -1301,7 +1301,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6
     }
 }
 
-static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
+static void quantize_row_q4_K_impl(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
     assert(n_per_row % QK_K == 0);
     const int64_t nb = n_per_row / QK_K;
 
@@ -1374,7 +1374,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
     }
 }
 
-size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row);
@@ -1392,7 +1392,7 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nr
 
 // ====================== 5-bit (de)-quantization
 
-void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) {
+void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -1454,8 +1454,8 @@ void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, in
             }
         }
 
-        uint8_t * restrict qh = y[i].qh;
-        uint8_t * restrict ql = y[i].qs;
+        uint8_t * GGML_RESTRICT qh = y[i].qh;
+        uint8_t * GGML_RESTRICT ql = y[i].qs;
         memset(qh, 0, QK_K/8);
 
         uint8_t m1 = 1, m2 = 2;
@@ -1479,7 +1479,7 @@ void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, in
     }
 }
 
-void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -1506,7 +1506,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6
     }
 }
 
-static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
+static void quantize_row_q5_K_impl(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
     assert(n_per_row % QK_K == 0);
     const int64_t nb = n_per_row / QK_K;
 
@@ -1573,8 +1573,8 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
             }
         }
 
-        uint8_t * restrict qh = y[i].qh;
-        uint8_t * restrict ql = y[i].qs;
+        uint8_t * GGML_RESTRICT qh = y[i].qh;
+        uint8_t * GGML_RESTRICT ql = y[i].qs;
         memset(qh, 0, QK_K/8);
 
         uint8_t m1 = 1, m2 = 2;
@@ -1599,7 +1599,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
     }
 }
 
-size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row);
@@ -1617,7 +1617,7 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nr
 
 // ====================== 6-bit (de)-quantization
 
-void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) {
+void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -1667,8 +1667,8 @@ void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, in
             }
         }
 
-        uint8_t * restrict ql = y[i].ql;
-        uint8_t * restrict qh = y[i].qh;
+        uint8_t * GGML_RESTRICT ql = y[i].ql;
+        uint8_t * GGML_RESTRICT qh = y[i].qh;
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) {
                 const uint8_t q1 = L[j + l +  0] & 0xF;
@@ -1687,16 +1687,16 @@ void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, in
     }
 }
 
-void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
     for (int i = 0; i < nb; i++) {
         const float d = GGML_FP16_TO_FP32(x[i].d);
 
-        const uint8_t * restrict ql = x[i].ql;
-        const uint8_t * restrict qh = x[i].qh;
-        const int8_t  * restrict sc = x[i].scales;
+        const uint8_t * GGML_RESTRICT ql = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT sc = x[i].scales;
 
         for (int n = 0; n < QK_K; n += 128) {
             for (int l = 0; l < 32; ++l) {
@@ -1718,7 +1718,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6
     }
 }
 
-static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) {
+static void quantize_row_q6_K_impl(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
     assert(n_per_row % QK_K == 0);
     const int64_t nb = n_per_row / QK_K;
 
@@ -1781,8 +1781,8 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
             }
         }
 
-        uint8_t * restrict ql = y[i].ql;
-        uint8_t * restrict qh = y[i].qh;
+        uint8_t * GGML_RESTRICT ql = y[i].ql;
+        uint8_t * GGML_RESTRICT qh = y[i].qh;
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) {
                 const uint8_t q1 = L[j + l +  0] & 0xF;
@@ -1802,7 +1802,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
     }
 }
 
-size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row);
@@ -1818,7 +1818,7 @@ size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nr
     return nrow * row_size;
 }
 
-static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
+static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
     static_assert(QK4_0 == 32, "QK4_0 must be 32");
 
     if (!quant_weights) {
@@ -1846,7 +1846,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
     }
 }
 
-size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
         quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
@@ -1861,7 +1861,7 @@ size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nr
     return nrow * row_size;
 }
 
-static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int64_t n_per_row, const float * quant_weights) {
+static void quantize_row_q4_1_impl(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
     static_assert(QK4_1 == 32, "QK4_1 must be 32");
 
     if (!quant_weights) {
@@ -1891,7 +1891,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
     }
 }
 
-size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
         quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
@@ -1906,7 +1906,7 @@ size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nr
     return nrow * row_size;
 }
 
-static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
+static void quantize_row_q5_0_impl(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
     static_assert(QK5_0 == 32, "QK5_0 must be 32");
 
     if (!quant_weights) {
@@ -1945,7 +1945,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
     }
 }
 
-size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
         quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
@@ -1960,7 +1960,7 @@ size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nr
     return nrow * row_size;
 }
 
-static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int64_t n_per_row, const float * quant_weights) {
+static void quantize_row_q5_1_impl(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
     static_assert(QK5_1 == 32, "QK5_1 must be 32");
 
     if (!quant_weights) {
@@ -1998,7 +1998,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
     }
 }
 
-size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
         quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
@@ -2013,7 +2013,7 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr
     return nrow * row_size;
 }
 
-size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     (void)quant_weights; // not used
     const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
     quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row);
@@ -2022,7 +2022,7 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
 
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
-void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) {
+void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2088,7 +2088,7 @@ void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y,
     }
 }
 
-void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) {
+void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2120,21 +2120,21 @@ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y,
     }
 }
 
-size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     (void)quant_weights; // not used
     const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row);
     quantize_row_tq1_0_ref(src, dst, (int64_t)nrow*n_per_row);
     return nrow * row_size;
 }
 
-size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     (void)quant_weights; // not used
     const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row);
     quantize_row_tq2_0_ref(src, dst, (int64_t)nrow*n_per_row);
     return nrow * row_size;
 }
 
-void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2173,7 +2173,7 @@ void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, in
     }
 }
 
-void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2194,7 +2194,7 @@ void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, in
 
 // ====================== "True" 2-bit (de)-quantization
 
-void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2222,7 +2222,7 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y
 
 // ====================== 2.3125 bpw (de)-quantization
 
-void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_iq2_xs(const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2249,7 +2249,7 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
 
 // ====================== 2.5625 bpw (de)-quantization
 
-void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_iq2_s(const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2281,7 +2281,7 @@ void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, in
 
 // ====================== 3.0625 bpw (de)-quantization
 
-void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2313,7 +2313,7 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
 
 // ====================== 3.3125 bpw (de)-quantization
 
-void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_iq3_s(const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2356,7 +2356,7 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
 
 // ====================== 1.5625 bpw (de)-quantization
 
-void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_iq1_s(const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2381,7 +2381,7 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
     }
 }
 
-void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_iq1_m(const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2433,7 +2433,7 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in
 
 static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
-void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_iq4_nl(const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK4_NL == 0);
     const int64_t nb = k / QK4_NL;
 
@@ -2451,7 +2451,7 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y,
     }
 }
 
-void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_iq4_xs(const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2476,7 +2476,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,
 
 //===================================== Q8_K ==============================================
 
-void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) {
+void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2515,7 +2515,7 @@ void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, in
     }
 }
 
-void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) {
+void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2927,8 +2927,8 @@ void iq2xs_free_impl(enum ggml_type type) {
     }
 }
 
-static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
-        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+static int iq2_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid,
+        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, int8_t * GGML_RESTRICT L) {
     int num_neighbors = neighbours[0];
     GGML_ASSERT(num_neighbors > 0);
     float best_d2 = FLT_MAX;
@@ -2951,7 +2951,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u
     return grid_index;
 }
 
-static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
+static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) {
 
     const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
 
@@ -3124,7 +3124,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
     }
 }
 
-static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
+static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) {
 
     const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
 
@@ -3304,7 +3304,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
     }
 }
 
-size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
     int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
@@ -3316,7 +3316,7 @@ size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int64_t
     return nrow * nblock * sizeof(block_iq2_xxs);
 }
 
-size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_iq2_xs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
     int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
@@ -3521,8 +3521,8 @@ void iq3xs_free_impl(int grid_size) {
     }
 }
 
-static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
-        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+static int iq3_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint32_t * GGML_RESTRICT grid,
+        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, int8_t * GGML_RESTRICT L) {
     int num_neighbors = neighbours[0];
     GGML_ASSERT(num_neighbors > 0);
     float best_d2 = FLT_MAX;
@@ -3545,8 +3545,8 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u
     return grid_index;
 }
 
-static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int64_t n,
-        const float * restrict quant_weights) {
+static void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n,
+        const float * GGML_RESTRICT quant_weights) {
 
     const int gindex = iq3_data_index(grid_size);
 
@@ -3758,7 +3758,7 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v
     }
 }
 
-size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
     int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
@@ -3770,13 +3770,13 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t
     return nrow * nblock * sizeof(block_iq3_xxs);
 }
 
-void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
+void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
 }
 
-static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, void * restrict vy, int n,
-        const float * restrict quant_weights,
+static void quantize_row_iq3_s_impl(int block_size, const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int n,
+        const float * GGML_RESTRICT quant_weights,
         float   * scales,
         float   * weight,
         float   * xval,
@@ -3958,7 +3958,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
 }
 
 #define IQ3S_BLOCK_SIZE 32
-size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_iq3_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
     int64_t nblock = n_per_row/QK_K;
     float scales[QK_K/IQ3S_BLOCK_SIZE];
@@ -3980,7 +3980,7 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n
     return nrow * nblock * sizeof(block_iq3_s);
 }
 
-void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
+void quantize_row_iq3_s_ref(const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_iq3_s(x, y, 1, k, NULL);
 }
@@ -3988,8 +3988,8 @@ void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y,
 
 // =================================== 1.5 bpw ===================================================
 
-static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
-        const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
+static int iq1_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid,
+        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float * scale, int8_t * GGML_RESTRICT L, int ngrid) {
     int num_neighbors = neighbours[0];
     GGML_ASSERT(num_neighbors > 0);
     float best_score = -FLT_MAX;
@@ -4048,8 +4048,8 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u
     return grid_index;
 }
 
-static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
-        const float * restrict xval, const float * restrict weight, float scale, const float * restrict xg, int8_t * restrict L, int ngrid) {
+static int iq1_find_best_neighbour2(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid,
+        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, const float * GGML_RESTRICT xg, int8_t * GGML_RESTRICT L, int ngrid) {
     int num_neighbors = neighbours[0];
     GGML_ASSERT(num_neighbors > 0);
     float best_score = FLT_MAX;
@@ -4113,7 +4113,7 @@ static int iq1_sort_helper(const void * left, const void * right) {
 
 #define IQ1S_BLOCK_SIZE 32
 #define IQ1M_BLOCK_SIZE 16
-static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights,
+static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights,
         float    * scales,
         float    * weight,
         float    * sumx,
@@ -4271,7 +4271,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
     }
 }
 
-size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_iq1_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
     float  scales[QK_K/IQ1S_BLOCK_SIZE];
     float  weight[IQ1S_BLOCK_SIZE];
@@ -4291,7 +4291,7 @@ size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t n
     return nrow * nblock * sizeof(block_iq1_s);
 }
 
-static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights,
+static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights,
         float    * scales,
         float    * weight,
         float    * pairs,
@@ -4539,7 +4539,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
     }
 }
 
-size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
     float  scales[QK_K/IQ1M_BLOCK_SIZE];
     float  weight[IQ1M_BLOCK_SIZE];
@@ -4570,7 +4570,7 @@ static inline int best_index_int8(int n, const int8_t * val, float x) {
     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
-static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * restrict x,
+static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
         ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
         float * scales, float * weight, uint8_t * L,
         const int8_t * values,
@@ -4681,7 +4681,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
     }
 }
 
-size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_iq4_nl(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK4_NL == 0);
     int64_t nblock = n_per_row/QK4_NL;
     char * qrow = (char *)dst;
@@ -4703,8 +4703,8 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t
     return nrow * nblock * sizeof(block_iq4_nl);
 }
 
-//void quantize_row_iq4_nl_ref(const float * restrict x, void * restrict vy, int64_t k) {
-void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
+//void quantize_row_iq4_nl_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+void quantize_row_iq4_nl_ref(const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k) {
     GGML_ASSERT(k%QK4_NL == 0);
     int64_t nblock = k/QK4_NL;
     uint8_t L[QK4_NL];
@@ -4719,7 +4719,7 @@ void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y
     }
 }
 
-size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_iq4_xs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
     int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
@@ -4739,14 +4739,14 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t
     return nrow * nblock * sizeof(block_iq4_xs);
 }
 
-void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
+void quantize_row_iq4_xs_ref(const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_iq4_xs(x, y, 1, k, NULL);
 }
 
 // =============================== 2.5625 bpw
 
-static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
+static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) {
 
     const int gindex = iq2_data_index(GGML_TYPE_IQ2_S);
 
@@ -4914,7 +4914,7 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy
     }
 }
 
-size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+size_t quantize_iq2_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
     int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
@@ -4926,7 +4926,7 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t n
     return nrow * nblock * sizeof(block_iq2_s);
 }
 
-void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
+void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_iq2_s(x, y, 1, k, NULL);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml.c b/ml/backend/ggml/ggml/src/ggml.c
index 635aa2990..2276b6312 100644
--- a/ml/backend/ggml/ggml/src/ggml.c
+++ b/ml/backend/ggml/ggml/src/ggml.c
@@ -565,9 +565,9 @@ FILE * ggml_fopen(const char * fname, const char * mode) {
 #endif
 
 }
-static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
-static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
-static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
+static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
+static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
 
 static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
     [GGML_TYPE_I8] = {
@@ -929,6 +929,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM",
     "RMS_NORM_BACK",
     "GROUP_NORM",
+    "L2_NORM",
 
     "MUL_MAT",
     "MUL_MAT_ID",
@@ -978,26 +979,22 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ADD_REL_POS",
     "RWKV_WKV6",
     "GATED_LINEAR_ATTN",
+    "RWKV_WKV7",
 
     "UNARY",
 
-    "MAP_UNARY",
-    "MAP_BINARY",
-
-    "MAP_CUSTOM1_F32",
-    "MAP_CUSTOM2_F32",
-    "MAP_CUSTOM3_F32",
-
     "MAP_CUSTOM1",
     "MAP_CUSTOM2",
     "MAP_CUSTOM3",
 
+    "CUSTOM",
+
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1027,6 +1024,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rms_norm(x)",
     "rms_norm_back(x)",
     "group_norm(x)",
+    "l2_norm(x)",
 
     "X*Y",
     "X[i]*Y",
@@ -1076,26 +1074,22 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
     "gated_linear_attn(k, v, q, gate, s)",
+    "rwkv_wkv7(r, w, k, v, a, b, s)",
 
     "unary(x)",
 
-    "f(x)",
-    "f(x,y)",
-
-    "custom_f32(x)",
-    "custom_f32(x,y)",
-    "custom_f32(x,y,z)",
+    "map_custom(x)",
+    "map_custom(x,y)",
+    "map_custom(x,y,z)",
 
     "custom(x)",
-    "custom(x,y)",
-    "custom(x,y,z)",
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1157,6 +1151,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (tensor->ne[i] <= 0) {
+            return 0;
+        }
+    }
+
     size_t nbytes;
     const size_t blck_size = ggml_blck_size(tensor->type);
     if (blck_size == 1) {
@@ -2334,6 +2334,7 @@ struct ggml_tensor * ggml_concat(
     struct ggml_tensor  * b,
     int                   dim) {
     GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+    GGML_ASSERT(a->type == b->type);
 
     int64_t ne[GGML_MAX_DIMS];
     for (int d = 0; d < GGML_MAX_DIMS; ++d) {
@@ -2687,6 +2688,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
     return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
 }
 
+// ggml_l2_norm
+
+static struct ggml_tensor * ggml_l2_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params_f32(result, 0, eps);
+
+    result->op     = GGML_OP_L2_NORM;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_l2_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps) {
+    return ggml_l2_norm_impl(ctx, a, eps, false);
+}
+
+struct ggml_tensor * ggml_l2_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps) {
+    return ggml_l2_norm_impl(ctx, a, eps, true);
+}
+
 // ggml_mul_mat
 
 static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
@@ -4144,7 +4176,8 @@ static struct ggml_tensor * ggml_upscale_impl(
         int                   ne0,
         int                   ne1,
         int                   ne2,
-        int                   ne3) {
+        int                   ne3,
+        enum ggml_scale_mode  mode) {
     GGML_ASSERT(a->ne[0] <= ne0);
     GGML_ASSERT(a->ne[1] <= ne1);
     GGML_ASSERT(a->ne[2] <= ne2);
@@ -4152,6 +4185,8 @@ static struct ggml_tensor * ggml_upscale_impl(
 
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
+    ggml_set_op_params_i32(result, 0, mode);
+
     result->op     = GGML_OP_UPSCALE;
     result->src[0] = a;
 
@@ -4161,8 +4196,9 @@ static struct ggml_tensor * ggml_upscale_impl(
 struct ggml_tensor * ggml_upscale(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   scale_factor) {
-    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+        int                   scale_factor,
+        enum ggml_scale_mode  mode) {
+    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
 }
 
 struct ggml_tensor * ggml_upscale_ext(
@@ -4171,8 +4207,9 @@ struct ggml_tensor * ggml_upscale_ext(
         int                   ne0,
         int                   ne1,
         int                   ne2,
-        int                   ne3) {
-    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
+        int                   ne3,
+        enum ggml_scale_mode  mode) {
+    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
 }
 
 // ggml_pad
@@ -4354,7 +4391,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     }
 
     // permute(0, 2, 1, 3)
-    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     float params[] = { scale, max_bias, logit_softcap };
@@ -4740,6 +4777,54 @@ struct ggml_tensor * ggml_gated_linear_attn(
     return result;
 }
 
+// ggml_rwkv_wkv7
+
+struct ggml_tensor * ggml_rwkv_wkv7(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * r,
+        struct ggml_tensor  * w,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(w));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(b));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
+        GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
+        GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_RWKV_WKV7;
+    result->src[0] = r;
+    result->src[1] = w;
+    result->src[2] = k;
+    result->src[3] = v;
+    result->src[4] = a;
+    result->src[5] = b;
+    result->src[6] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
@@ -4773,179 +4858,6 @@ struct ggml_tensor * ggml_unary_inplace(
     return ggml_unary_impl(ctx, a, op, true);
 }
 
-// ggml_map_unary
-
-static struct ggml_tensor * ggml_map_unary_impl_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun,
-        bool                         inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_UNARY;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_unary_inplace_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_binary
-
-static struct ggml_tensor * ggml_map_binary_impl_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun,
-        bool                          inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_BINARY;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_binary_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_binary_inplace_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom1_f32
-
-static struct ggml_tensor * ggml_map_custom1_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM1_F32;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom1_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom1_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_custom2_f32
-
-static struct ggml_tensor * ggml_map_custom2_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM2_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom2_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom2_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom3_f32
-
-static struct ggml_tensor * ggml_map_custom3_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM3_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom3_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom3_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
-}
-
 // ggml_map_custom1
 
 static struct ggml_tensor * ggml_map_custom1_impl(
@@ -4964,7 +4876,7 @@ static struct ggml_tensor * ggml_map_custom1_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM1;
     result->src[0] = a;
@@ -5009,7 +4921,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM2;
     result->src[0] = a;
@@ -5058,7 +4970,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM3;
     result->src[0] = a;
@@ -5090,6 +5002,66 @@ struct ggml_tensor * ggml_map_custom3_inplace(
     return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
 }
 
+struct ggml_tensor * ggml_custom_4d(
+        struct ggml_context * ctx,
+        enum ggml_type        type,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3,
+        struct ggml_tensor ** args,
+        int                   n_args,
+        ggml_custom_op_t      fun,
+        int                   n_tasks,
+        void                * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i] = args[i];
+    }
+
+    return result;
+}
+
+struct ggml_tensor * ggml_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor ** args,
+        int                   n_args,
+        ggml_custom_op_t      fun,
+        int                   n_tasks,
+        void                * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    result->src[0] = a;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i + 1] = args[i];
+    }
+
+    return result;
+}
 // ggml_cross_entropy_loss
 
 struct ggml_tensor * ggml_cross_entropy_loss(
diff --git a/ml/backend/ggml/ggml/src/gguf.cpp b/ml/backend/ggml/ggml/src/gguf.cpp
index f75b923fd..e45b453de 100644
--- a/ml/backend/ggml/ggml/src/gguf.cpp
+++ b/ml/backend/ggml/ggml/src/gguf.cpp
@@ -935,6 +935,7 @@ static void gguf_check_reserved_keys(const std::string & key, const T val) {
         if constexpr (std::is_same<T, uint32_t>::value) {
             GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT " must be power of 2");
         } else {
+            GGML_UNUSED(val);
             GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT " must be type u32");
         }
     }