Skip to content

Commit a0ea2b0

Browse files
committed
Move conversion script into exllamav2 package
1 parent 6509e90 commit a0ea2b0

16 files changed

+324
-323
lines changed

convert.py

+1-313
Original file line number | Diff line number | Diff line change
@@ -1,313 +1 @@
1-
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer
2-
from exllamav2.architecture import RopeStyle
3-
import argparse, os, shutil
4-
import sys
5-
import json
6-
from conversion.tokenize import tokenize
7-
from conversion.measure import embeddings, measure_quant
8-
from conversion.quantize import quant
9-
from conversion.optimize import optimize
10-
from conversion.compile import compile_model
11-
from conversion.qparams import qparams_headoptions
12-
import torch
13-
14-
parser = argparse.ArgumentParser(description = "Convert model to ExLlamaV2")
15-
parser.add_argument("-i", "--in_dir", type = str, help = "Input directory", default = "")
16-
parser.add_argument("-o", "--out_dir", type = str, help = "Output (working) directory")
17-
parser.add_argument("-res", "--resume", action = "store_true", help = "Resume job from specified output directory (without specifying other options)")
18-
parser.add_argument("-nr", "--no_resume", action = "store_true", help = "Do not resume an interrupted job (deletes all files in the output directory)")
19-
parser.add_argument("-cf", "--compile_full", type = str, help = "Output folder for compiled model with all config/tokenizer files")
20-
parser.add_argument("-c", "--cal_dataset", type = str, help = "Calibration dataset (.parquet file)")
21-
parser.add_argument("-b", "--bits", type = float, default = 4.125, help = "Target bits per weight")
22-
parser.add_argument("-ss", "--shard_size", type = float, help = "Max shard size in MB (default: 8192)", default = 8192)
23-
parser.add_argument("-rs", "--rope_scale", type = float, help = "RoPE scaling factor")
24-
parser.add_argument("-ra", "--rope_alpha", type = float, help = "RoPE alpha value (NTK)")
25-
parser.add_argument("-hb", "--head_bits", type = int, default = 6, help = "Target bits per weight (head layer)")
26-
parser.add_argument("-om", "--output_measurement", type = str, help = "Only perform measurement pass, then save measurement to the specified file")
27-
parser.add_argument("-m", "--measurement", type = str, help = "Reuse previous measurement")
28-
parser.add_argument("-r", "--dataset_rows", type = int, default = 100, help = "Number of rows to apply from dataset")
29-
parser.add_argument("-mr", "--measurement_rows", type = int, default = 16, help = "Number of rows to apply from dataset when measuring")
30-
parser.add_argument("-l", "--length", type = int, default = 2048, help = "Max no. tokens per sample")
31-
parser.add_argument("-ml", "--measurement_length", type = int, default = 2048, help = "Max no. tokens per sample when measuring")
32-
parser.add_argument("-so", "--status_output", action = "store_true", help = "Include machine-parseable status updates in console output")
33-
parser.add_argument("-hsol", "--hidden_state_offload_layers", type = int, default = 0, help = "Number of hidden/target states to keep in VRAM. Speed-up but increases VRAM usage")
34-
35-
args = parser.parse_args()
36-
37-
torch.set_printoptions(precision = 7, sci_mode = False, linewidth = 200)
38-
39-
# Check some args
40-
41-
resuming = False
42-
if args.out_dir:
43-
if not args.no_resume:
44-
if os.path.exists(os.path.join(args.out_dir, "job_new.json")):
45-
resuming = True
46-
else:
47-
print(" ## Please specify output/working directory (-o, --out_dir)")
48-
sys.exit()
49-
50-
if not args.in_dir and not resuming:
51-
print(" ## Please specify input model directory (-i, --in_dir)")
52-
sys.exit()
53-
54-
if args.length > 2048 or args.measurement_length > 2048:
55-
print(" !! Warning: calibration rows > 2048 tokens may result in excessive VRAM use")
56-
57-
if not args.head_bits in qparams_headoptions:
58-
print(f" ## Error: {args.head_bits} is not a supported option for head layer bitrate")
59-
sys.exit()
60-
61-
if args.output_measurement is not None and args.compile_full is not None:
62-
print(" ## Conflicting options: --output_measurement and --compile_full")
63-
sys.exit()
64-
65-
if args.bits < 2 or args.bits > 8:
66-
print(f" !! Warning: target bitrate {args.bits} will likely not be attainable")
67-
68-
if not os.path.exists(args.out_dir):
69-
print(f" ## Error: Directory not found: {args.out_dir}")
70-
sys.exit()
71-
72-
# Create job
73-
74-
def save_job():
75-
global job_file, job
76-
with open(job_file, "w", encoding = "utf8") as f:
77-
f.write(json.dumps(job, indent = 4))
78-
79-
job_file = os.path.join(args.out_dir, "job_new.json")
80-
81-
if args.no_resume or not os.path.exists(job_file):
82-
83-
print(f" -- Beginning new job")
84-
if len(os.listdir(args.out_dir)) != 0:
85-
print(f" !! Warning: Output directory is not empty: {args.out_dir}")
86-
87-
if args.no_resume:
88-
print(f" !! Cleaning output directory: {args.out_dir}")
89-
for filename in os.listdir(args.out_dir):
90-
file_path = os.path.join(args.out_dir, filename)
91-
if os.path.isfile(file_path):
92-
os.unlink(file_path)
93-
elif os.path.isdir(file_path):
94-
shutil.rmtree(file_path)
95-
96-
output_measurement = args.output_measurement
97-
if output_measurement is not None:
98-
if os.path.isdir(output_measurement):
99-
output_measurement = os.path.join(output_measurement, "measurement.json")
100-
101-
job = {"in_dir": args.in_dir,
102-
"out_dir": args.out_dir,
103-
"cal_dataset": args.cal_dataset,
104-
"bits": args.bits,
105-
"dataset_rows": args.dataset_rows,
106-
"measurement_rows": args.measurement_rows,
107-
"length": args.length,
108-
"measurement_length": args.measurement_length,
109-
"head_bits": args.head_bits,
110-
"shard_size": args.shard_size if args.shard_size > 0 else 1024 ** 3, # 1 PB = unlimited,
111-
"compile_full": args.compile_full,
112-
"rope_scale": args.rope_scale,
113-
"rope_alpha": args.rope_alpha,
114-
"output_measurement": output_measurement,
115-
"progress": "begin"}
116-
117-
if args.measurement is not None:
118-
with open(args.measurement, "r", encoding = "utf8") as f:
119-
imp_measurement = json.load(f)
120-
job["measurement"] = imp_measurement["measurement"]
121-
job["last_module_idx"] = imp_measurement["last_module_idx"]
122-
job["reuse_measurement"] = args.measurement
123-
124-
# Resume existing job
125-
126-
if args.no_resume or not os.path.exists(job_file):
127-
pass
128-
129-
else:
130-
print(f" -- Resuming job")
131-
if args.in_dir:
132-
print(f" !! Note: Overriding options with settings from existing job")
133-
134-
with open(job_file, "r", encoding = "utf8") as f:
135-
resume_job = json.load(f)
136-
137-
# Override keys in existing job
138-
del resume_job["out_dir"]
139-
140-
job.update(resume_job)
141-
if "invalid" in job:
142-
print(" ** Error: Corrupted job")
143-
sys.exit()
144-
145-
if job["progress"] == "finished":
146-
print(" !! Job is already finished")
147-
sys.exit()
148-
149-
# Feedback
150-
151-
print(f" -- Input: {job['in_dir']}")
152-
print(f" -- Output: {job['out_dir']}")
153-
if job.get("cal_dataset"):
154-
print(f" -- Calibration dataset: {job['cal_dataset']}, {job['dataset_rows']} / {job['measurement_rows']} rows, {job['length']} tokens per sample")
155-
else:
156-
print(f" -- Using default calibration dataset")
157-
if job["output_measurement"] is None:
158-
print(f" -- Target bits per weight: {job['bits']} (decoder), {job['head_bits']} (head)")
159-
print(f" -- Max shard size: {job['shard_size']} MB")
160-
else:
161-
print(f" -- Measurement will be saved to {job['output_measurement']}")
162-
print(f" !! Conversion script will end after measurement pass")
163-
164-
if job['rope_scale']: print(f" -- RoPE scale: {job['rope_scale']:.2f}")
165-
if job['rope_alpha']: print(f" -- RoPE alpha: {job['rope_alpha']:.2f}")
166-
167-
# Make sure subfolders exist
168-
169-
if job.get("compile_full"):
170-
print(f" -- Full model will be compiled to: {job['compile_full']}")
171-
if os.path.exists(job["compile_full"]):
172-
if not os.path.isdir(job["compile_full"]):
173-
print(f" ## Error: Output path {job['compile_full']} exists but is not a directory")
174-
sys.exit()
175-
if len(os.listdir(job["compile_full"])) > 0:
176-
print(f" !! Warning: Output path {job['compile_full']} exists but is not empty")
177-
178-
out_tensor_dir = os.path.join(job["out_dir"], "out_tensor")
179-
if not os.path.exists(out_tensor_dir):
180-
os.makedirs(out_tensor_dir)
181-
182-
# Create config
183-
184-
config = ExLlamaV2Config()
185-
config.model_dir = job['in_dir']
186-
config.qkv_embed = False
187-
config.prepare()
188-
189-
# Tokenizer
190-
191-
tokenizer = ExLlamaV2Tokenizer(config)
192-
193-
# Set scaling for input model
194-
195-
if job["rope_scale"] is not None: config.scale_pos_emb = job["rope_scale"]
196-
if job["rope_alpha"] is not None: config.scale_alpha_value = job["rope_alpha"]
197-
198-
# Create model without loading weights
199-
200-
model = ExLlamaV2(config)
201-
model.load(lazy = True)
202-
203-
# Limit context length if necessary
204-
205-
if model.config.arch.rope_style == RopeStyle.NONE:
206-
max_ctx = model.config.max_seq_len
207-
if job["length"] > max_ctx:
208-
print (f" !! Warning: Reducing calibration length to model max context: {max_ctx}")
209-
job["length"] = max_ctx
210-
if job["measurement_length"] > max_ctx:
211-
print (f" !! Warning: Reducing measurement calibration length to model max context: {max_ctx}")
212-
job["measurement_length"] = max_ctx
213-
214-
# Overridable settings
215-
216-
job["status_output"] = args.status_output
217-
218-
# Do the things
219-
220-
save_job()
221-
222-
while True:
223-
224-
progress = job["progress"]
225-
226-
if progress == "begin":
227-
228-
if "reuse_measurement" in job:
229-
230-
print(f" -- Reusing measurement: {job['reuse_measurement']}")
231-
job["progress"] = "optimize"
232-
save_job()
233-
234-
else:
235-
236-
print(f" -- Tokenizing samples (measurement)...")
237-
tokenize(job, save_job, tokenizer, measure = True)
238-
job["progress"] = "initial_embeddings"
239-
save_job()
240-
241-
if progress == "initial_embeddings":
242-
243-
print(f" -- Token embeddings (measurement)...")
244-
embeddings(job, save_job, model)
245-
job["progress"] = "measure_quant"
246-
save_job()
247-
248-
if progress == "measure_quant":
249-
print(f" -- Measuring quantization impact...")
250-
251-
model.unload()
252-
config.max_output_len = 16
253-
model = ExLlamaV2(config)
254-
model.load(lazy = True)
255-
256-
status = measure_quant(job, save_job, model, args.hidden_state_offload_layers) # capturing the graceful exits
257-
if status == "interrupted":
258-
print("Process interrupted. Exiting gracefully.")
259-
save_job()
260-
sys.exit(1)
261-
if job["output_measurement"] is None:
262-
job["progress"] = "optimize"
263-
else:
264-
job["progress"] = "finished"
265-
save_job()
266-
267-
model.unload()
268-
config.max_output_len = None
269-
model = ExLlamaV2(config)
270-
model.load(lazy = True)
271-
272-
if progress == "optimize":
273-
274-
print(f" -- Optimizing...")
275-
optimize(job, save_job, model)
276-
job["progress"] = "tokens_cal"
277-
save_job()
278-
279-
if progress == "tokens_cal":
280-
281-
print(f" -- Tokenizing samples...")
282-
tokenize(job, save_job, tokenizer)
283-
job["progress"] = "embeddings"
284-
save_job()
285-
286-
if progress == "embeddings":
287-
print(f" -- Token embeddings again...")
288-
embeddings(job, save_job, model)
289-
job["progress"] = "quant"
290-
save_job()
291-
292-
if progress == "quant":
293-
294-
print(f" -- Quantizing...")
295-
quant(job, save_job, model)
296-
job["progress"] = "compile"
297-
save_job()
298-
299-
if progress == "compile":
300-
301-
print(f" -- Compiling output file...")
302-
compile_model(job, save_job, model)
303-
job["progress"] = "finished"
304-
save_job()
305-
306-
if progress == "finished": break
307-
308-
print(f" -- Finished")
309-
310-
311-
312-
313-
1+
import exllamav2.conversion.convert_exl2
File renamed without changes.
File renamed without changes.

conversion/compile.py exllamav2/conversion/compile.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
import os, glob, shutil, json
1818
from safetensors import safe_open
1919
from safetensors.torch import save_file
20-
from conversion.bot_status import print_stage
20+
from exllamav2.conversion.bot_status import print_stage
2121

2222
def _tsize(t):
2323

0 commit comments

Comments
 (0)