Commit 42ae14b

add support for scaled fp8 tensors

1 parent 56d850d

1 file changed: model.cpp (+44, -1)

@@ -111,8 +111,12 @@ const char* unused_tensors[] = {
     "embedding_manager",
     "denoiser.sigmas",
     "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
+    "text_encoders.t5xxl.logit_scale", // only used during training
+    "text_encoders.t5xxl.transformer.scaled_fp8",
     "text_encoders.qwen2vl.output.weight",
     "text_encoders.qwen2vl.lm_head.",
+    "text_encoders.qwen2vl.logit_scale", // only used during training
+    "text_encoders.qwen2vl.transformer.scaled_fp8",
 };
 
 bool is_unused_tensor(std::string name) {
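
The trailing dot in "text_encoders.qwen2vl.lm_head." suggests these entries are matched as prefixes of tensor names, which is how the new scaled_fp8 marker tensors get filtered out before loading. A minimal sketch of such a prefix check, assuming each table entry is compared against the start of the name (the repo's actual is_unused_tensor may differ in detail):

```cpp
#include <cstring>
#include <string>

// Hypothetical subset of the unused_tensors table above.
static const char* unused_prefixes[] = {
    "text_encoders.t5xxl.transformer.scaled_fp8",
    "text_encoders.qwen2vl.lm_head.",  // trailing dot: matches any lm_head.* tensor
};

// Sketch only: treat each table entry as a name prefix.
static bool is_unused_tensor_sketch(const std::string& name) {
    for (size_t i = 0; i < sizeof(unused_prefixes) / sizeof(unused_prefixes[0]); i++) {
        if (name.compare(0, strlen(unused_prefixes[i]), unused_prefixes[i]) == 0) {
            return true;  // e.g. "text_encoders.qwen2vl.lm_head.weight"
        }
    }
    return false;
}
```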
@@ -2084,6 +2088,26 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
     std::atomic<bool> failed(false);
     std::vector<std::thread> workers;
 
+    std::unordered_map<std::string, int> scale_idx;
+    std::unordered_map<std::string, int> scale_count;
+    for (int i = 0; i < file_tensors.size(); i++) {
+        const TensorStorage* tensor = file_tensors[i];
+        if (ends_with(tensor->name, ".scale_weight")) {
+            std::string new_name = tensor->name.substr(0, tensor->name.size() - strlen(".scale_weight")) + ".weight";
+            GGML_ASSERT(tensor->nelements() == 1 && tensor->type == GGML_TYPE_F32 && tensor->nbytes_to_read() == 4);
+            scale_idx[new_name] = i;
+            scale_count[new_name]++;
+        } else if (ends_with(tensor->name, ".weight") && (tensor->is_f8_e4m3fn || tensor->is_f8_e5m2)) {
+            scale_count[tensor->name]--;
+        }
+    }
+    for (auto& x : scale_count) {
+        if (x.second > 0) {
+            LOG_ERROR("f8 weight not found for scale_weight: '%s'", x.first.c_str());
+            return false;
+        }
+    }
+
     for (int i = 0; i < n_threads; ++i) {
         workers.emplace_back([&, file_path, is_zip]() {
             std::ifstream file;
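
This first pass pairs every `foo.scale_weight` with its base tensor `foo.weight`: `scale_idx` remembers where the scale lives in `file_tensors`, while `scale_count` is incremented per scale and decremented per matching f8 weight, so any entry still positive after the pass is an orphaned scale and aborts the load. A standalone sketch of that pairing, with hypothetical tensor names and a plain suffix check standing in for the repo's `ends_with` helper:

```cpp
#include <cstring>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Plain suffix check, standing in for the repo's ends_with helper.
static bool ends_with(const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main() {
    // Hypothetical tensor names as they might appear in a scaled fp8 checkpoint.
    std::vector<std::string> names = {
        "model.layers.0.attn.q.weight",        // f8 weight
        "model.layers.0.attn.q.scale_weight",  // its f32 scale
        "model.layers.1.attn.q.scale_weight",  // orphan: no matching weight
    };
    std::unordered_map<std::string, int> scale_count;
    for (const auto& name : names) {
        if (ends_with(name, ".scale_weight")) {
            // "foo.scale_weight" -> "foo.weight", the same mapping the commit uses
            std::string base = name.substr(0, name.size() - strlen(".scale_weight")) + ".weight";
            scale_count[base]++;
        } else if (ends_with(name, ".weight")) {  // the real code also requires an f8 dtype
            scale_count[name]--;
        }
    }
    for (const auto& x : scale_count) {
        if (x.second > 0) {
            // Mirrors the commit's error path: a scale with no f8 weight is an error.
            std::cout << "f8 weight not found for scale_weight: " << x.first << "\n";
        }
    }
    return 0;
}
```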
@@ -2114,6 +2138,15 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
             } else {
                 fail |= !tensor_storage.read_data(buf, file);
             }
+            float scale = 1;
+            if (scale_idx.count(tensor_storage.name)) {
+                const TensorStorage* tensor = file_tensors[scale_idx[tensor_storage.name]];
+                if (is_zip) {
+                    fail |= !tensor->read_data(&scale, zip, memcpy_time_ms);
+                } else {
+                    fail |= !tensor->read_data(&scale, file);
+                }
+            }
             if (fail) {
                 failed = true;
                 LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
@@ -2125,11 +2158,17 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
             if (tensor_storage.is_f8_e4m3fn) {
                 for (int64_t i = count - 1; i >= 0; i--) {
                     static_cast<float*>(buf)[i] = f8_e4m3fn_to_f32(static_cast<uint8_t*>(buf)[i]);
+                    if (scale != 1) {
+                        static_cast<float*>(buf)[i] *= scale;
+                    }
                 }
             } else if (tensor_storage.is_f8_e5m2) {
                 for (int64_t i = count - 1; i >= 0; i--) {
                     static_cast<float*>(buf)[i] =
                         ggml_fp16_to_fp32(f8_e5m2_to_f16(static_cast<uint8_t*>(buf)[i]));
+                    if (scale != 1) {
+                        static_cast<float*>(buf)[i] *= scale;
+                    }
                 }
             } else if (tensor_storage.is_f64) {
                 for (int64_t i = 0; i < count; i++) {
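
Both fp8 branches walk the buffer backwards because the conversion widens in place: the 4-byte float written at index i would otherwise clobber source bytes that have not been read yet, and starting from the end avoids that. For reference, a self-contained sketch of the E4M3FN decode plus the per-tensor rescale applied above; the helper name and scale value are illustrative, not the repo's code:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// fp8 E4M3FN: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits,
// no infinities, NaN encoded as exponent and mantissa all ones.
static float f8_e4m3fn_to_f32_sketch(uint8_t v) {
    int sign = (v >> 7) & 1;
    int exp  = (v >> 3) & 0xF;
    int mant = v & 0x7;
    float s = sign ? -1.0f : 1.0f;
    if (exp == 0xF && mant == 0x7) return NAN;             // the only non-finite encoding
    if (exp == 0) return s * std::ldexp(mant / 8.0f, -6);  // subnormal
    return s * std::ldexp(1.0f + mant / 8.0f, exp - 7);    // normal
}

int main() {
    uint8_t raw = 0x40;      // sign 0, exp 8, mant 0 -> 2^(8-7) = 2.0
    float scale = 0.03125f;  // hypothetical per-tensor ".scale_weight" value
    float w = f8_e4m3fn_to_f32_sketch(raw);
    printf("decoded %.4f, rescaled %.6f\n", w, w * scale);  // 2.0000, 0.062500
    return 0;
}
```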
@@ -2154,7 +2193,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
             }
 
             const TensorStorage& tensor_storage = *file_tensors[idx];
-            ggml_tensor* dst_tensor = nullptr;
+            if (ends_with(tensor_storage.name, ".scale_weight")) {
+                continue;
+            }
+
+            ggml_tensor* dst_tensor = nullptr;
 
             if (!on_new_tensor_cb(tensor_storage, &dst_tensor)) {
                 LOG_WARN("process tensor failed: '%s'", tensor_storage.name.c_str());
