Commit 32dee05

add support for scaled fp8 tensors
1 parent 60efca9 commit 32dee05
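This commit teaches the model loader to apply per-tensor scales to fp8 (E4M3FN / E5M2) checkpoints: a quantized `<name>.weight` may ship with a one-element f32 companion tensor `<name>.scale_weight`, and that scale is multiplied into the weights while they are dequantized to f32. As background, here is a minimal sketch of E4M3FN decoding; the real `f8_e4m3fn_to_f32` helper already exists elsewhere in model.cpp and is not part of this diff, so treat the body below as illustrative only.

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical re-implementation of the f8_e4m3fn_to_f32 helper used in the
// diff below. E4M3FN layout: 1 sign bit, 4 exponent bits (bias 7),
// 3 mantissa bits, no infinities, NaN only at exponent == 15, mantissa == 7.
float f8_e4m3fn_to_f32_sketch(uint8_t f8) {
    float sign     = (f8 & 0x80) ? -1.0f : 1.0f;
    int   exponent = (f8 >> 3) & 0x0F;
    int   mantissa = f8 & 0x07;
    if (exponent == 0x0F && mantissa == 0x07) {
        return NAN;  // the only special encoding in E4M3FN
    }
    if (exponent == 0) {
        // subnormal: 2^(1-7) * (mantissa / 8)
        return sign * std::ldexp(mantissa / 8.0f, -6);
    }
    // normal: 2^(exponent-7) * (1 + mantissa / 8)
    return sign * std::ldexp(1.0f + mantissa / 8.0f, exponent - 7);
}
```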

File tree

1 file changed: +44 −1 lines changed

model.cpp

Lines changed: 44 additions & 1 deletion
@@ -111,8 +111,12 @@ const char* unused_tensors[] = {
     "embedding_manager",
     "denoiser.sigmas",
     "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
+    "text_encoders.t5xxl.logit_scale", // only used during training
+    "text_encoders.t5xxl.transformer.scaled_fp8",
     "text_encoders.qwen2vl.output.weight",
     "text_encoders.qwen2vl.lm_head.",
+    "text_encoders.qwen2vl.logit_scale", // only used during training
+    "text_encoders.qwen2vl.transformer.scaled_fp8",
 };
 
 bool is_unused_tensor(std::string name) {
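This first hunk only extends the `unused_tensors` blocklist so that the new bookkeeping tensors (`logit_scale` and the `scaled_fp8` marker) are skipped at load time. `is_unused_tensor` is defined right below the hunk but its body is not shown; assuming it does prefix matching (which the trailing dot in `"text_encoders.qwen2vl.lm_head."` suggests), a sketch of the check could look like this:

```cpp
// A minimal sketch of the prefix check, under the assumption above; the real
// is_unused_tensor lives in model.cpp and may differ in detail.
bool is_unused_tensor_sketch(const std::string& name) {
    for (const char* prefix : unused_tensors) {
        if (name.rfind(prefix, 0) == 0) {  // name starts with prefix
            return true;
        }
    }
    return false;
}
```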
@@ -2088,6 +2092,26 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
     std::atomic<bool> failed(false);
     std::vector<std::thread> workers;
 
+    std::unordered_map<std::string, int> scale_idx;
+    std::unordered_map<std::string, int> scale_count;
+    for (int i = 0; i < file_tensors.size(); i++) {
+        const TensorStorage* tensor = file_tensors[i];
+        if (ends_with(tensor->name, ".scale_weight")) {
+            std::string new_name = tensor->name.substr(0, tensor->name.size() - strlen(".scale_weight")) + ".weight";
+            GGML_ASSERT(tensor->nelements() == 1 && tensor->type == GGML_TYPE_F32 && tensor->nbytes_to_read() == 4);
+            scale_idx[new_name] = i;
+            scale_count[new_name]++;
+        } else if (ends_with(tensor->name, ".weight") && (tensor->is_f8_e4m3fn || tensor->is_f8_e5m2)) {
+            scale_count[tensor->name]--;
+        }
+    }
+    for (auto& x : scale_count) {
+        if (x.second > 0) {
+            LOG_ERROR("f8 weight not found for scale_weight: '%s'", x.first.c_str());
+            return false;
+        }
+    }
+
     for (int i = 0; i < n_threads; ++i) {
         workers.emplace_back([&, file_path, is_zip]() {
             std::ifstream file;
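This pre-pass pairs every `.scale_weight` with the fp8 `.weight` it belongs to before any worker thread starts reading. `scale_idx` maps the derived weight name to the index of its scale tensor, while `scale_count` acts as a tally: each scale increments the counter under the derived name and each fp8 weight decrements it, so any entry left positive is a scale with no matching fp8 weight, and loading fails early. A self-contained demo of the tally, with made-up tensor names:

```cpp
#include <cstdio>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Demo of the pairing check: each ".scale_weight" increments a counter under
// the derived ".weight" name, each fp8 ".weight" decrements it, and a
// positive residue flags an orphaned scale.
int main() {
    std::vector<std::pair<std::string, bool>> tensors = {
        {"blk.0.attn.weight", true},         // fp8 weight
        {"blk.0.attn.scale_weight", false},  // its scale
        {"blk.1.ffn.scale_weight", false},   // orphan: no fp8 weight
    };
    std::unordered_map<std::string, int> scale_count;
    const std::string suffix = ".scale_weight";
    for (auto& [name, is_f8] : tensors) {
        if (name.size() > suffix.size() &&
            name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0) {
            scale_count[name.substr(0, name.size() - suffix.size()) + ".weight"]++;
        } else if (is_f8) {
            scale_count[name]--;
        }
    }
    for (auto& [name, count] : scale_count) {
        if (count > 0) {
            printf("orphaned scale for: %s\n", name.c_str());  // blk.1.ffn.weight
        }
    }
}
```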
@@ -2118,6 +2142,15 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
             } else {
                 fail |= !tensor_storage.read_data(buf, file);
             }
+            float scale = 1;
+            if (scale_idx.count(tensor_storage.name)) {
+                const TensorStorage* tensor = file_tensors[scale_idx[tensor_storage.name]];
+                if (is_zip) {
+                    fail |= !tensor->read_data(&scale, zip, memcpy_time_ms);
+                } else {
+                    fail |= !tensor->read_data(&scale, file);
+                }
+            }
             if (fail) {
                 failed = true;
                 LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
@@ -2129,11 +2162,17 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
             if (tensor_storage.is_f8_e4m3fn) {
                 for (int64_t i = count - 1; i >= 0; i--) {
                     static_cast<float*>(buf)[i] = f8_e4m3fn_to_f32(static_cast<uint8_t*>(buf)[i]);
+                    if (scale != 1) {
+                        static_cast<float*>(buf)[i] *= scale;
+                    }
                 }
             } else if (tensor_storage.is_f8_e5m2) {
                 for (int64_t i = count - 1; i >= 0; i--) {
                     static_cast<float*>(buf)[i] =
                         ggml_fp16_to_fp32(f8_e5m2_to_f16(static_cast<uint8_t*>(buf)[i]));
+                    if (scale != 1) {
+                        static_cast<float*>(buf)[i] *= scale;
+                    }
                 }
             } else if (tensor_storage.is_f64) {
                 for (int64_t i = 0; i < count; i++) {
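Note that the fp8 dequant loops run backwards (`i = count - 1; i >= 0; i--`): the one-byte fp8 values and the four-byte f32 results share the same buffer, and converting from the end keeps each write from clobbering source bytes that have not been read yet. The scale multiply simply slots into the same loop. A toy example of the in-place widening pattern (the 0.5f factor stands in for the scale):

```cpp
#include <cstdint>
#include <cstdio>

// buf starts with N one-byte values and ends with N four-byte floats in the
// same allocation. Writing element i at byte offset 4*i only touches source
// bytes <= 4*i, so iterating from N-1 down to 0 never destroys unread input.
int main() {
    alignas(float) uint8_t buf[4 * 4] = {10, 20, 30, 40};  // room for 4 floats
    int64_t count = 4;
    for (int64_t i = count - 1; i >= 0; i--) {
        float v = (float)buf[i];                      // read 1-byte source
        reinterpret_cast<float*>(buf)[i] = v * 0.5f;  // write 4-byte result
    }
    for (int64_t i = 0; i < count; i++) {
        printf("%g\n", reinterpret_cast<float*>(buf)[i]);  // 5 10 15 20
    }
}
```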
@@ -2158,7 +2197,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
             }
 
             const TensorStorage& tensor_storage = *file_tensors[idx];
-            ggml_tensor* dst_tensor = nullptr;
+            if (ends_with(tensor_storage.name, ".scale_weight")) {
+                continue;
+            }
+
+            ggml_tensor* dst_tensor = nullptr;
 
             if (!on_new_tensor_cb(tensor_storage, &dst_tensor)) {
                 LOG_WARN("process tensor failed: '%s'", tensor_storage.name.c_str());
