sort devices by type

This commit is contained in:
Michael Yang 2025-01-14 12:04:59 -08:00
parent 5827999e9e
commit 800e21d060
2 changed files with 95 additions and 8 deletions

View File

@ -0,0 +1,82 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Tue, 14 Jan 2025 12:01:24 -0800
Subject: [PATCH] sort devices by score
---
ggml/src/ggml-backend-reg.cpp | 21 +++++++++++++--------
1 file changed, 13 insertions(+), 8 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 899d16f2..ac5cda07 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -150,7 +150,7 @@ struct ggml_backend_reg_entry {
struct ggml_backend_registry {
std::vector<ggml_backend_reg_entry> backends;
- std::vector<ggml_backend_dev_t> devices;
+ std::vector<std::pair<ggml_backend_dev_t, int>> devices;
ggml_backend_registry() {
#ifdef GGML_USE_CUDA
@@ -195,7 +195,7 @@ struct ggml_backend_registry {
}
}
- void register_backend(ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) {
+ void register_backend(ggml_backend_reg_t reg, int score = -1, dl_handle_ptr handle = nullptr) {
if (!reg) {
return;
}
@@ -206,15 +206,15 @@ struct ggml_backend_registry {
#endif
backends.push_back({ reg, std::move(handle) });
for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
- register_device(ggml_backend_reg_dev_get(reg, i));
+ register_device(ggml_backend_reg_dev_get(reg, i), score);
}
}
- void register_device(ggml_backend_dev_t device) {
+ void register_device(ggml_backend_dev_t device, int score = -1) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
#endif
- devices.push_back(device);
+ devices.push_back({device, score});
}
ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
@@ -257,7 +257,7 @@ struct ggml_backend_registry {
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
- register_backend(reg, std::move(handle));
+ register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
return reg;
}
@@ -280,7 +280,7 @@ struct ggml_backend_registry {
// remove devices
devices.erase(
std::remove_if(devices.begin(), devices.end(),
- [reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }),
+ [reg](std::pair<ggml_backend_dev_t, int> dev) { return ggml_backend_dev_backend_reg(dev.first) == reg; }),
devices.end());
// remove backend
@@ -338,7 +338,12 @@ size_t ggml_backend_dev_count() {
ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
GGML_ASSERT(index < ggml_backend_dev_count());
- return get_reg().devices[index];
+ auto devices = get_reg().devices;
+ if (!std::is_heap(devices.begin(), devices.end())) {
+ std::make_heap(devices.begin(), devices.end(), [](const auto & a, const auto & b) { return a.second < b.second; });
+ }
+
+ return devices[index].first;
}
ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {

View File

@ -150,7 +150,7 @@ struct ggml_backend_reg_entry {
struct ggml_backend_registry { struct ggml_backend_registry {
std::vector<ggml_backend_reg_entry> backends; std::vector<ggml_backend_reg_entry> backends;
std::vector<ggml_backend_dev_t> devices; std::vector<std::pair<ggml_backend_dev_t, int>> devices;
ggml_backend_registry() { ggml_backend_registry() {
#ifdef GGML_USE_CUDA #ifdef GGML_USE_CUDA
@ -195,7 +195,7 @@ struct ggml_backend_registry {
} }
} }
void register_backend(ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) { void register_backend(ggml_backend_reg_t reg, int score = -1, dl_handle_ptr handle = nullptr) {
if (!reg) { if (!reg) {
return; return;
} }
@ -206,15 +206,15 @@ struct ggml_backend_registry {
#endif #endif
backends.push_back({ reg, std::move(handle) }); backends.push_back({ reg, std::move(handle) });
for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) { for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
register_device(ggml_backend_reg_dev_get(reg, i)); register_device(ggml_backend_reg_dev_get(reg, i), score);
} }
} }
void register_device(ggml_backend_dev_t device) { void register_device(ggml_backend_dev_t device, int score = -1) {
#ifndef NDEBUG #ifndef NDEBUG
GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device)); GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
#endif #endif
devices.push_back(device); devices.push_back({device, score});
} }
ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) { ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
@ -257,7 +257,7 @@ struct ggml_backend_registry {
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str()); GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
register_backend(reg, std::move(handle)); register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
return reg; return reg;
} }
@ -280,7 +280,7 @@ struct ggml_backend_registry {
// remove devices // remove devices
devices.erase( devices.erase(
std::remove_if(devices.begin(), devices.end(), std::remove_if(devices.begin(), devices.end(),
[reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }), [reg](std::pair<ggml_backend_dev_t, int> dev) { return ggml_backend_dev_backend_reg(dev.first) == reg; }),
devices.end()); devices.end());
// remove backend // remove backend
@ -338,7 +338,12 @@ size_t ggml_backend_dev_count() {
ggml_backend_dev_t ggml_backend_dev_get(size_t index) { ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
GGML_ASSERT(index < ggml_backend_dev_count()); GGML_ASSERT(index < ggml_backend_dev_count());
return get_reg().devices[index]; auto devices = get_reg().devices;
if (!std::is_heap(devices.begin(), devices.end())) {
std::make_heap(devices.begin(), devices.end(), [](const auto & a, const auto & b) { return a.second < b.second; });
}
return devices[index].first;
} }
ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) { ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {