diff --git a/common/common.cpp b/common/common.cpp index 5f70b2a6da049..24cf0363119d5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1636,6 +1636,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cvector_negative_file = argv[i]; return true; } + if (arg == "--single-prompt") { + params.single_prompt = true; + return true; + } if (arg == "--completions") { if (++i >= argc) { invalid_param = true; @@ -1985,6 +1989,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "cvector", "-o, --output FNAME", "output file (default: '%s')", params.cvector_outfile.c_str() }); options.push_back({ "cvector", "--positive-file FNAME", "positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str() }); options.push_back({ "cvector", "--negative-file FNAME", "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() }); + options.push_back({ "cvector", "--single-prompt", "assume prompt files only contain one prompt (for multiline prompts)" }); options.push_back({ "cvector", "--completions-file FNAME","completions file (default: '%s')", params.cvector_completions_file.c_str() }); options.push_back({ "cvector", "--completions N", "number of lines of completions file to use (default: %d)", params.n_completions }); options.push_back({ "cvector", "--batch-pca N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch }); diff --git a/common/common.h b/common/common.h index fa44859c6a63f..50513f2c19877 100644 --- a/common/common.h +++ b/common/common.h @@ -241,6 +241,7 @@ struct gpt_params { std::string cvector_completions_file = "examples/control-vector-generator/completions.txt"; std::string cvector_positive_file = "examples/control-vector-generator/positive.txt"; std::string cvector_negative_file = "examples/control-vector-generator/negative.txt"; + bool single_prompt = false; }; void gpt_params_handle_model_default(gpt_params & params); diff --git a/examples/control-vector-generator/control-vector-generator.cpp b/examples/control-vector-generator/control-vector-generator.cpp index 136f78974ca97..d3c29cc948172 100644 --- a/examples/control-vector-generator/control-vector-generator.cpp +++ b/examples/control-vector-generator/control-vector-generator.cpp @@ -289,7 +289,7 @@ static std::string to_string(const T & val) { return ss.str(); } -static std::vector ctrlvec_load_prompt_file(std::string path, bool skip_empty_lines = false) { +static std::vector ctrlvec_load_prompt_file(std::string path, bool skip_empty_lines = false, bool single_prompt = false) { std::vector output; std::ifstream file(path); if (!file.is_open()) { @@ -304,6 +304,14 @@ static std::vector ctrlvec_load_prompt_file(std::string path, bool } } file.close(); + if (single_prompt) { + std::string single_prompt; + for (const auto & line : output) { + single_prompt += line + "\n"; + } + output.clear(); + output.push_back(single_prompt); + } return output; } @@ -362,8 +370,8 @@ static void export_gguf(const std::vector & v_ctrl, const */ static int prepare_entries(gpt_params & params, train_context & ctx_train) { // load prompts - std::vector positive_prompts = ctrlvec_load_prompt_file(params.cvector_positive_file); - std::vector negative_prompts = ctrlvec_load_prompt_file(params.cvector_negative_file); + std::vector positive_prompts = ctrlvec_load_prompt_file(params.cvector_positive_file, false, params.single_prompt); + std::vector negative_prompts = ctrlvec_load_prompt_file(params.cvector_negative_file, false, params.single_prompt); if (positive_prompts.size() != negative_prompts.size()) { fprintf(stderr, "number of positive and negative prompts must be equal\n"); return 1; diff --git a/examples/control-vector-generator/pca.hpp b/examples/control-vector-generator/pca.hpp index 903f6cc7c9ec7..28f8bd3e9f4ca 100644 --- a/examples/control-vector-generator/pca.hpp +++ b/examples/control-vector-generator/pca.hpp @@ -11,6 +11,7 @@ #endif #include +#include #include #include #include