From 322ad2a276682bd91ea23d0c154a4b5d3d25aee2 Mon Sep 17 00:00:00 2001
From: Haoyu Wang <894321963@163.com>
Date: Wed, 23 Aug 2023 16:15:41 +0800
Subject: [PATCH] [BugFix] add is_load check when pipeline_para_size > 1 and
 int8_mode != 0

---
 examples/pytorch/gpt/utils/gpt.py | 63 ++++++++++++++++---------------
 1 file changed, 32 insertions(+), 31 deletions(-)

diff --git a/examples/pytorch/gpt/utils/gpt.py b/examples/pytorch/gpt/utils/gpt.py
index 20d90b45f..9af55db3f 100644
--- a/examples/pytorch/gpt/utils/gpt.py
+++ b/examples/pytorch/gpt/utils/gpt.py
@@ -403,38 +403,39 @@ def load_to_torch(file_path: str, is_load: bool):
         layer_num = self.layer_num
         if self.int8_mode != 0:
             for i in range(layer_num):
-                self.int8_w[i + 0 * layer_num], self.scale[i + 0 *
-                                                           layer_num] = self.weight_transpose_calibrate_quantize(self.w[2 * layer_num + i])
-                self.int8_w[i + 1 * layer_num], self.scale[i + 1 *
-                                                           layer_num] = self.weight_transpose_calibrate_quantize(self.w[4 * layer_num + i])
-                self.int8_w[i + 2 * layer_num], self.scale[i + 2 *
-                                                           layer_num] = self.weight_transpose_calibrate_quantize(self.w[8 * layer_num + i])
-                self.int8_w[i + 3 * layer_num], self.scale[i + 3 *
-                                                           layer_num] = self.weight_transpose_calibrate_quantize(self.w[10 * layer_num + i])
-
-                # We clear the original weights since they are no longer needed
-                if self.int8_mode == 1:
-                    self.w[2 * layer_num + i] = torch.empty(0).to(str_type_map[self.inference_data_type])
-                    self.w[4 * layer_num + i] = torch.empty(0).to(str_type_map[self.inference_data_type])
-                    self.w[8 * layer_num + i] = torch.empty(0).to(str_type_map[self.inference_data_type])
-                    self.w[10 * layer_num + i] = torch.empty(0).to(str_type_map[self.inference_data_type])
-
-                if self.has_adapters:
-                    self.int8_w[i + 4 * layer_num], self.scale[i + 4 * layer_num] = self.weight_transpose_calibrate_quantize(
-                        self.w[12 * layer_num + i + self.adapter_offset])
-                    self.int8_w[i + 5 * layer_num], self.scale[i + 5 * layer_num] = self.weight_transpose_calibrate_quantize(
-                        self.w[14 * layer_num + i + self.adapter_offset])
-                    self.int8_w[i + 6 * layer_num], self.scale[i + 6 * layer_num] = self.weight_transpose_calibrate_quantize(
-                        self.w[16 * layer_num + i + self.adapter_offset])
-                    self.int8_w[i + 7 * layer_num], self.scale[i + 7 * layer_num] = self.weight_transpose_calibrate_quantize(
-                        self.w[18 * layer_num + i + self.adapter_offset])
-
-                    # Similar to above:
-                    if self.int8_mode == 1:
-                        self.w[12 * layer_num + i + self.adapter_offset] = torch.empty(0).to(str_type_map[self.inference_data_type])
-                        self.w[14 * layer_num + i + self.adapter_offset] = torch.empty(0).to(str_type_map[self.inference_data_type])
-                        self.w[16 * layer_num + i + self.adapter_offset] = torch.empty(0).to(str_type_map[self.inference_data_type])
-                        self.w[18 * layer_num + i + self.adapter_offset] = torch.empty(0).to(str_type_map[self.inference_data_type])
+                if is_load(i):
+                    self.int8_w[i + 0 * layer_num], self.scale[i + 0 *
+                                                               layer_num] = self.weight_transpose_calibrate_quantize(self.w[2 * layer_num + i])
+                    self.int8_w[i + 1 * layer_num], self.scale[i + 1 *
+                                                               layer_num] = self.weight_transpose_calibrate_quantize(self.w[4 * layer_num + i])
+                    self.int8_w[i + 2 * layer_num], self.scale[i + 2 *
+                                                               layer_num] = self.weight_transpose_calibrate_quantize(self.w[8 * layer_num + i])
+                    self.int8_w[i + 3 * layer_num], self.scale[i + 3 *
+                                                               layer_num] = self.weight_transpose_calibrate_quantize(self.w[10 * layer_num + i])
+
+                    # We clear the original weights since they are no longer needed
+                    if self.int8_mode == 1:
+                        self.w[2 * layer_num + i] = torch.empty(0).to(str_type_map[self.inference_data_type])
+                        self.w[4 * layer_num + i] = torch.empty(0).to(str_type_map[self.inference_data_type])
+                        self.w[8 * layer_num + i] = torch.empty(0).to(str_type_map[self.inference_data_type])
+                        self.w[10 * layer_num + i] = torch.empty(0).to(str_type_map[self.inference_data_type])
+
+                    if self.has_adapters:
+                        self.int8_w[i + 4 * layer_num], self.scale[i + 4 * layer_num] = self.weight_transpose_calibrate_quantize(
+                            self.w[12 * layer_num + i + self.adapter_offset])
+                        self.int8_w[i + 5 * layer_num], self.scale[i + 5 * layer_num] = self.weight_transpose_calibrate_quantize(
+                            self.w[14 * layer_num + i + self.adapter_offset])
+                        self.int8_w[i + 6 * layer_num], self.scale[i + 6 * layer_num] = self.weight_transpose_calibrate_quantize(
+                            self.w[16 * layer_num + i + self.adapter_offset])
+                        self.int8_w[i + 7 * layer_num], self.scale[i + 7 * layer_num] = self.weight_transpose_calibrate_quantize(
+                            self.w[18 * layer_num + i + self.adapter_offset])
+
+                        # Similar to above:
+                        if self.int8_mode == 1:
+                            self.w[12 * layer_num + i + self.adapter_offset] = torch.empty(0).to(str_type_map[self.inference_data_type])
+                            self.w[14 * layer_num + i + self.adapter_offset] = torch.empty(0).to(str_type_map[self.inference_data_type])
+                            self.w[16 * layer_num + i + self.adapter_offset] = torch.empty(0).to(str_type_map[self.inference_data_type])
+                            self.w[18 * layer_num + i + self.adapter_offset] = torch.empty(0).to(str_type_map[self.inference_data_type])
         return True
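
Why the guard is needed: with pipeline_para_size > 1, each pipeline rank only
materializes the weights for its own slice of layers, so the self.w entries
belonging to other ranks stay empty and calling
weight_transpose_calibrate_quantize on them can fail or waste work. Below is a
minimal sketch of an is_load predicate expressing that ownership rule; the
names layers_per_device and pipeline_para_rank are assumptions for
illustration, not necessarily the exact helper gpt.py defines.

    # Minimal sketch of an is_load predicate under pipeline parallelism.
    # Assumed names (layers_per_device, pipeline_para_rank) are illustrative.
    def make_is_load(layer_num: int, pipeline_para_size: int, pipeline_para_rank: int):
        layers_per_device = layer_num // pipeline_para_size

        def is_load(i: int) -> bool:
            # A layer's weights exist (and hence can be quantized) only on
            # the rank that owns its pipeline slice.
            return (layers_per_device * pipeline_para_rank
                    <= i
                    < layers_per_device * (pipeline_para_rank + 1))

        return is_load

    # Example: 24 layers over 4 stages; rank 1 owns layers 6..11, so the
    # int8 quantization loop in the patch runs only for those i on that rank.
    is_load = make_is_load(layer_num=24, pipeline_para_size=4, pipeline_para_rank=1)
    assert [i for i in range(24) if is_load(i)] == list(range(6, 12))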