Skip to content

Commit 2b1bc06

Browse files
bssrdfbssrdf
and
bssrdf
authored
feat: add PhotoMaker Version 2 support (leejet#358)
* first attempt at updating to photomaker v2 * continue adding photomaker v2 modules * finishing the last few pieces for photomaker v2; id_embeds need to be done by a manual step and pass as an input file * added a name converter for Photomaker V2; build ok * more debugging underway * failing at cuda mat_mul * updated chunk_half to be more efficient; redo feedforward * fixed a bug: carefully using ggml_view_4d to get chunks of a tensor; strides need to be recalculated or set properly; still failing at soft_max cuda op * redo weight calculation and weight*v * fixed a bug now Photomaker V2 kinds of working * add python script for face detection (Photomaker V2 needs) * updated readme for photomaker * fixed a bug causing PMV1 crashing; both V1 and V2 work * fixed clean_input_ids for PMV2 * fixed a double counting bug in tokenize_with_trigger_token * updated photomaker readme * removed some commented code * improved reconstructing class word free prompt * changed reading id_embed to raw binary using existing load tensor function; this is more efficient than using model load and also makes it easier to work with sd server * minor clean up --------- Co-authored-by: bssrdf <[email protected]>
1 parent b99cbfe commit 2b1bc06

11 files changed

+845
-57
lines changed

clip.hpp

+22-5
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,14 @@ class CLIPTokenizer {
343343
}
344344
}
345345

346+
std::string clean_up_tokenization(std::string &text){
347+
348+
std::regex pattern(R"( ,)");
349+
// Replace " ," with ","
350+
std::string result = std::regex_replace(text, pattern, ",");
351+
return result;
352+
}
353+
346354
std::string decode(const std::vector<int>& tokens) {
347355
std::string text = "";
348356
for (int t : tokens) {
@@ -351,8 +359,12 @@ class CLIPTokenizer {
351359
std::u32string ts = decoder[t];
352360
// printf("%d, %s \n", t, utf32_to_utf8(ts).c_str());
353361
std::string s = utf32_to_utf8(ts);
354-
if (s.length() >= 4 && ends_with(s, "</w>")) {
355-
text += " " + s.replace(s.length() - 4, s.length() - 1, "");
362+
if (s.length() >= 4 ){
363+
if(ends_with(s, "</w>")) {
364+
text += s.replace(s.length() - 4, s.length() - 1, "") + " ";
365+
}else{
366+
text += s;
367+
}
356368
} else {
357369
text += " " + s;
358370
}
@@ -364,6 +376,7 @@ class CLIPTokenizer {
364376

365377
// std::string s((char *)bytes.data());
366378
// std::string s = "";
379+
text = clean_up_tokenization(text);
367380
return trim(text);
368381
}
369382

@@ -755,7 +768,8 @@ class CLIPVisionModel : public GGMLBlock {
755768
blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
756769
}
757770

758-
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true) {
771+
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values,
772+
bool return_pooled = true) {
759773
// pixel_values: [N, num_channels, image_size, image_size]
760774
auto embeddings = std::dynamic_pointer_cast<CLIPVisionEmbeddings>(blocks["embeddings"]);
761775
auto pre_layernorm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_layernorm"]);
@@ -765,14 +779,17 @@ class CLIPVisionModel : public GGMLBlock {
765779
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
766780
x = pre_layernorm->forward(ctx, x);
767781
x = encoder->forward(ctx, x, -1, false);
782+
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
783+
auto last_hidden_state = x;
768784
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
769785

770-
GGML_ASSERT(x->ne[3] == 1);
786+
GGML_ASSERT(x->ne[3] == 1);
771787
if (return_pooled) {
772788
ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0));
773789
return pooled; // [N, hidden_size]
774790
} else {
775-
return x; // [N, n_token, hidden_size]
791+
// return x; // [N, n_token, hidden_size]
792+
return last_hidden_state; // [N, n_token, hidden_size]
776793
}
777794
}
778795
};

conditioner.hpp

+15-9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "clip.hpp"
55
#include "t5.hpp"
66

7+
78
struct SDCondition {
89
struct ggml_tensor* c_crossattn = NULL; // aka context
910
struct ggml_tensor* c_vector = NULL; // aka y
@@ -44,6 +45,7 @@ struct Conditioner {
4445
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
4546
struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
4647
SDVersion version = VERSION_SD1;
48+
PMVersion pm_version = VERSION_1;
4749
CLIPTokenizer tokenizer;
4850
ggml_type wtype;
4951
std::shared_ptr<CLIPTextModelRunner> text_model;
@@ -59,8 +61,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
5961
ggml_type wtype,
6062
const std::string& embd_dir,
6163
SDVersion version = VERSION_SD1,
64+
PMVersion pv = VERSION_1,
6265
int clip_skip = -1)
63-
: version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
66+
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
6467
if (clip_skip <= 0) {
6568
clip_skip = 1;
6669
if (version == VERSION_SD2 || version == VERSION_SDXL) {
@@ -159,7 +162,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
159162
tokenize_with_trigger_token(std::string text,
160163
int num_input_imgs,
161164
int32_t image_token,
162-
bool padding = false) {
165+
bool padding = false){
163166
return tokenize_with_trigger_token(text, num_input_imgs, image_token,
164167
text_model->model.n_token, padding);
165168
}
@@ -268,7 +271,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
268271
std::vector<int> clean_input_ids_tmp;
269272
for (uint32_t i = 0; i < class_token_index[0]; i++)
270273
clean_input_ids_tmp.push_back(clean_input_ids[i]);
271-
for (uint32_t i = 0; i < num_input_imgs; i++)
274+
for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs); i++)
272275
clean_input_ids_tmp.push_back(class_token);
273276
for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
274277
clean_input_ids_tmp.push_back(clean_input_ids[i]);
@@ -279,13 +282,16 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
279282
tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
280283
weights.insert(weights.end(), clean_input_ids.size(), curr_weight);
281284
}
282-
tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID);
283-
weights.insert(weights.begin(), 1.0);
285+
// BUG!! double couting, pad_tokens will add BOS at the beginning
286+
// tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID);
287+
// weights.insert(weights.begin(), 1.0);
284288

285289
tokenizer.pad_tokens(tokens, weights, max_length, padding);
286-
290+
int offset = pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs;
287291
for (uint32_t i = 0; i < tokens.size(); i++) {
288-
if (class_idx + 1 <= i && i < class_idx + 1 + num_input_imgs)
292+
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
293+
if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs
294+
// hardcode for now
289295
class_token_mask.push_back(true);
290296
else
291297
class_token_mask.push_back(false);
@@ -530,7 +536,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
530536
int height,
531537
int num_input_imgs,
532538
int adm_in_channels = -1,
533-
bool force_zero_embeddings = false) {
539+
bool force_zero_embeddings = false){
534540
auto image_tokens = convert_token_to_id(trigger_word);
535541
// if(image_tokens.size() == 1){
536542
// printf(" image token id is: %d \n", image_tokens[0]);
@@ -958,7 +964,7 @@ struct SD3CLIPEmbedder : public Conditioner {
958964
int height,
959965
int num_input_imgs,
960966
int adm_in_channels = -1,
961-
bool force_zero_embeddings = false) {
967+
bool force_zero_embeddings = false){
962968
GGML_ASSERT(0 && "Not implemented yet!");
963969
}
964970

docs/photo_maker.md

+23-1
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,26 @@ Example:
2929

3030
```bash
3131
bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png
32-
```
32+
```
33+
34+
## PhotoMaker Version 2
35+
36+
[PhotoMaker Version 2 (PMV2)](https://github.com/TencentARC/PhotoMaker/blob/main/README_pmv2.md) has some key improvements. Unfortunately it has a very heavy dependency which makes running it a bit involved in ```SD.cpp```.
37+
38+
Running PMV2 is now a two-step process:
39+
40+
- Run a python script ```face_detect.py``` to obtain **id_embeds** for the given input images
41+
```
42+
python face_detect.py input_image_dir
43+
```
44+
An ```id_embeds.safetensors``` file will be generated in ```input_images_dir```
45+
46+
**Note: this step is only needed to run once; the same ```id_embeds``` can be reused**
47+
48+
- Run the same command as in version 1 but replacing ```photomaker-v1.safetensors``` with ```photomaker-v2.safetensors```.
49+
50+
You can download ```photomaker-v2.safetensors``` from [here](https://huggingface.co/bssrdf/PhotoMakerV2)
51+
52+
- All the command line parameters from Version 1 remain the same for Version 2
53+
54+

face_detect.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import os
2+
import sys
3+
4+
import numpy as np
5+
import torch
6+
from diffusers.utils import load_image
7+
# pip install insightface==0.7.3
8+
from insightface.app import FaceAnalysis
9+
from insightface.data import get_image as ins_get_image
10+
from safetensors.torch import save_file
11+
12+
###
13+
# https://github.com/cubiq/ComfyUI_IPAdapter_plus/issues/165#issue-2055829543
14+
###
15+
class FaceAnalysis2(FaceAnalysis):
16+
# NOTE: allows setting det_size for each detection call.
17+
# the model allows it but the wrapping code from insightface
18+
# doesn't show it, and people end up loading duplicate models
19+
# for different sizes where there is absolutely no need to
20+
def get(self, img, max_num=0, det_size=(640, 640)):
21+
if det_size is not None:
22+
self.det_model.input_size = det_size
23+
24+
return super().get(img, max_num)
25+
26+
def analyze_faces(face_analysis: FaceAnalysis, img_data: np.ndarray, det_size=(640, 640)):
27+
# NOTE: try detect faces, if no faces detected, lower det_size until it does
28+
detection_sizes = [None] + [(size, size) for size in range(640, 256, -64)] + [(256, 256)]
29+
30+
for size in detection_sizes:
31+
faces = face_analysis.get(img_data, det_size=size)
32+
if len(faces) > 0:
33+
return faces
34+
35+
return []
36+
37+
if __name__ == "__main__":
38+
#face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
39+
face_detector = FaceAnalysis2(providers=['CPUExecutionProvider'], allowed_modules=['detection', 'recognition'])
40+
face_detector.prepare(ctx_id=0, det_size=(640, 640))
41+
#input_folder_name = './scarletthead_woman'
42+
input_folder_name = sys.argv[1]
43+
image_basename_list = os.listdir(input_folder_name)
44+
image_path_list = sorted([os.path.join(input_folder_name, basename) for basename in image_basename_list])
45+
46+
input_id_images = []
47+
for image_path in image_path_list:
48+
input_id_images.append(load_image(image_path))
49+
50+
id_embed_list = []
51+
52+
for img in input_id_images:
53+
img = np.array(img)
54+
img = img[:, :, ::-1]
55+
faces = analyze_faces(face_detector, img)
56+
if len(faces) > 0:
57+
id_embed_list.append(torch.from_numpy((faces[0]['embedding'])))
58+
59+
if len(id_embed_list) == 0:
60+
raise ValueError(f"No face detected in input image pool")
61+
62+
id_embeds = torch.stack(id_embed_list)
63+
64+
# for r in id_embeds:
65+
# print(r)
66+
# #torch.save(id_embeds, input_folder_name+'/id_embeds.pt');
67+
# weights = dict()
68+
# weights["id_embeds"] = id_embeds
69+
# save_file(weights, input_folder_name+'/id_embeds.safetensors')
70+
71+
binary_data = id_embeds.numpy().tobytes()
72+
two = 4
73+
zero = 0
74+
one = 1
75+
tensor_name = "id_embeds"
76+
# Write binary data to a file
77+
with open(input_folder_name+'/id_embeds.bin', "wb") as f:
78+
f.write(two.to_bytes(4, byteorder='little'))
79+
f.write((len(tensor_name)).to_bytes(4, byteorder='little'))
80+
f.write(zero.to_bytes(4, byteorder='little'))
81+
f.write((id_embeds.shape[1]).to_bytes(4, byteorder='little'))
82+
f.write((id_embeds.shape[0]).to_bytes(4, byteorder='little'))
83+
f.write(one.to_bytes(4, byteorder='little'))
84+
f.write(one.to_bytes(4, byteorder='little'))
85+
f.write(tensor_name.encode('ascii'))
86+
f.write(binary_data)
87+
88+

ggml_extend.hpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,11 @@ struct GGMLRunner {
10471047
params_buffer_size / (1024.0 * 1024.0),
10481048
ggml_backend_is_cpu(backend) ? "RAM" : "VRAM",
10491049
num_tensors);
1050+
// printf("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)\n",
1051+
// get_desc().c_str(),
1052+
// params_buffer_size / (1024.0 * 1024.0),
1053+
// ggml_backend_is_cpu(backend) ? "RAM" : "VRAM",
1054+
// num_tensors);
10501055
return true;
10511056
}
10521057

@@ -1216,7 +1221,8 @@ class Linear : public UnaryBlock {
12161221
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
12171222
if (bias) {
12181223
params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_features);
1219-
}
1224+
}
1225+
12201226
}
12211227

12221228
public:

model.cpp

+40-2
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,33 @@ std::unordered_map<std::string, std::string> vae_decoder_name_map = {
146146
{"first_stage_model.decoder.mid.attn_1.to_v.weight", "first_stage_model.decoder.mid.attn_1.v.weight"},
147147
};
148148

149+
std::unordered_map<std::string, std::string> pmid_v2_name_map = {
150+
{"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight",
151+
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"},
152+
{"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight",
153+
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"},
154+
{"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight",
155+
"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"},
156+
{"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight",
157+
"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"},
158+
{"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight",
159+
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"},
160+
{"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight",
161+
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"},
162+
{"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight",
163+
"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"},
164+
{"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight",
165+
"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"},
166+
{"pmid.qformer_perceiver.token_proj.0.bias",
167+
"pmid.qformer_perceiver.token_proj.fc1.bias"},
168+
{"pmid.qformer_perceiver.token_proj.2.bias",
169+
"pmid.qformer_perceiver.token_proj.fc2.bias"},
170+
{"pmid.qformer_perceiver.token_proj.0.weight",
171+
"pmid.qformer_perceiver.token_proj.fc1.weight"},
172+
{"pmid.qformer_perceiver.token_proj.2.weight",
173+
"pmid.qformer_perceiver.token_proj.fc2.weight"},
174+
};
175+
149176
std::string convert_open_clip_to_hf_clip(const std::string& name) {
150177
std::string new_name = name;
151178
std::string prefix;
@@ -212,6 +239,13 @@ std::string convert_vae_decoder_name(const std::string& name) {
212239
return name;
213240
}
214241

242+
std::string convert_pmid_v2_name(const std::string& name) {
243+
if (pmid_v2_name_map.find(name) != pmid_v2_name_map.end()) {
244+
return pmid_v2_name_map[name];
245+
}
246+
return name;
247+
}
248+
215249
/* If not a SDXL LoRA the unet" prefix will have already been replaced by this
216250
* point and "te2" and "te1" don't seem to appear in non-SDXL only "te_" */
217251
std::string convert_sdxl_lora_name(std::string tensor_name) {
@@ -443,6 +477,8 @@ std::string convert_tensor_name(std::string name) {
443477
new_name = convert_open_clip_to_hf_clip(name);
444478
} else if (starts_with(name, "first_stage_model.decoder")) {
445479
new_name = convert_vae_decoder_name(name);
480+
} else if (starts_with(name, "pmid.qformer_perceiver")) {
481+
new_name = convert_pmid_v2_name(name);
446482
} else if (starts_with(name, "control_model.")) { // for controlnet pth models
447483
size_t pos = name.find('.');
448484
if (pos != std::string::npos) {
@@ -1015,7 +1051,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10151051
}
10161052

10171053
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
1018-
tensor_storage.reverse_ne();
1054+
tensor_storage.reverse_ne();
10191055

10201056
size_t tensor_data_size = end - begin;
10211057

@@ -1362,7 +1398,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
13621398
reader.tensor_storage.reverse_ne();
13631399
reader.tensor_storage.file_index = file_index;
13641400
// if(strcmp(prefix.c_str(), "scarlett") == 0)
1365-
// printf(" got tensor %s \n ", reader.tensor_storage.name.c_str());
1401+
// printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
13661402
reader.tensor_storage.name = prefix + reader.tensor_storage.name;
13671403
tensor_storages.push_back(reader.tensor_storage);
13681404
// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
@@ -1398,7 +1434,9 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
13981434
std::string name = zip_entry_name(zip);
13991435
size_t pos = name.find("data.pkl");
14001436
if (pos != std::string::npos) {
1437+
14011438
std::string dir = name.substr(0, pos);
1439+
printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str());
14021440
void* pkl_data = NULL;
14031441
size_t pkl_size;
14041442
zip_entry_read(zip, &pkl_data, &pkl_size);

model.h

+6
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ enum SDVersion {
3131
VERSION_COUNT,
3232
};
3333

34+
enum PMVersion {
35+
VERSION_1,
36+
VERSION_2,
37+
};
38+
3439
struct TensorStorage {
3540
std::string name;
3641
ggml_type type = GGML_TYPE_F32;
@@ -162,6 +167,7 @@ class ModelLoader {
162167
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
163168
ggml_backend_t backend,
164169
std::set<std::string> ignore_tensors = {});
170+
165171
bool save_to_gguf_file(const std::string& file_path, ggml_type type);
166172
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
167173
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);

0 commit comments

Comments
 (0)