
finetune.cpp command-line arg #13873


Open: graehl wants to merge 1 commit into master from finelayer

Conversation


@graehl graehl commented May 28, 2025

Add to ggml-opt a learning-rate (AdamW alpha) command-line argument, plus an optimizer-type enum defaulting to AdamW, in preparation for supporting SGD.

In common args these become a set of optimizer options active only for the new FINETUNE example (which, as a precaution, also includes all of the PERPLEXITY options that finetune.cpp previously used).

Perhaps breaking with precedent, the ggml_opt_optimizer_params struct is included directly in the args; if desired, we can instead add just the learning rate and optimizer type to a struct independent of ggml-opt.h.

As proposed in
#13835

@graehl graehl requested a review from JohannesGaessler as a code owner May 28, 2025 20:26
@github-actions github-actions bot added the examples and ggml (changes relating to the ggml tensor library for machine learning) labels May 28, 2025
@graehl
Author

graehl commented May 28, 2025

Perhaps no need to review until I have an actual SGD impl in a follow-on, @JohannesGaessler, but a few general questions about contributing:

  1. Is it OK to make small retouches to ggml/ sources in this (llama.cpp) project with the expectation of getting the changes into the actual ggml repo later? Are there any plans to submodule a ggml-in-llama branch to keep things straight(er)?
  2. Is what I've got here the expected way to add example-specific command-line arguments? For finetune we definitely at least want to be able to vary the learning rate, which was formerly hard-coded.
  3. Were the PERPLEXITY args that I blindly added to the new FINETUNE example actually doing anything interesting? Perhaps some should be dropped from finetune.
  4. Could you direct me to a .clang-format style file that might save me from accidentally re-indenting? I know I can set up clang-format to operate only on regions I've already changed...

Contributor
@WilliamTambellini WilliamTambellini left a comment

Better to keep that change as is for now; it will take time to get more feedback/approval.

@JohannesGaessler
Collaborator

Is it OK to make small retouches to ggml/ sources in this (llama.cpp) project with the expectation of getting the changes into the actual ggml repo later? Are there any plans to submodule a ggml-in-llama branch to keep things straight(er)?

Any changes made to the ggml source in this repository will eventually be synced to the ggml repository and vice versa; it is completely fine. I think the issue of a git submodule was previously brought up and rejected.

Is what I've got here the expected way to add example-specific command-line arguments? For finetune we definitely at least want to be able to vary the learning rate, which was formerly hard-coded.

My opinion is that people serious about training should be writing a program rather than using a command-line tool. Still, I think it's good to make things such as the learning rate configurable in the provided example program.

Were the PERPLEXITY args that I blindly added to the new FINETUNE example actually doing anything interesting? Perhaps some should be dropped from finetune.

I don't remember whether those args were put in by me when I copy-pasted code or by Georgi when he later refactored it, but I definitely did not make an intentional choice to use these exact arguments.

Could you direct me to a .clang-format style file that might save me from accidentally re-indenting? I know I can set up clang-format to operate only on regions I've already changed...

I don't know, sorry.

@WilliamTambellini
Contributor

@ggerganov

Collaborator
@JohannesGaessler JohannesGaessler left a comment

None of the previous perplexity-specific arguments are needed.

@JohannesGaessler
Collaborator

For adding an SGD optimizer, add a new ggml op like OPT_STEP_SGD. Add a CPU implementation as a fallback for any backend without an implementation. Add a CUDA implementation, since that is (I assume) the backend you intend to use in production. Add a test to tests/test-backend-ops.cpp to assert that the CPU and CUDA backends produce consistent results. Extend ggml-opt.cpp to conditionally use the new SGD optimizer step, and condition the allocation of the optimizer momenta on the optimizer type.
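For concreteness, a minimal standalone sketch of the per-parameter update such an OPT_STEP_SGD op would perform; the helper name, signature, and use of decoupled weight decay below are assumptions for illustration, not the ggml implementation.

// Illustrative sketch only, not the ggml kernel: plain SGD with decoupled weight decay.
#include <cstddef>
#include <cstdio>
#include <vector>

static void opt_step_sgd_ref(float * w, const float * g, size_t n, float alpha, float wd) {
    const float keep = 1.0f - alpha * wd;   // assumed decoupled weight decay, as in the AdamW op
    for (size_t i = 0; i < n; ++i) {
        w[i] = w[i] * keep - alpha * g[i];  // plain SGD step; no m/v momenta needed
    }
}

int main() {
    std::vector<float>       w = { 1.0f, -2.0f, 0.5f };
    const std::vector<float> g = { 0.1f,  0.2f, -0.3f };
    opt_step_sgd_ref(w.data(), g.data(), w.size(), /*alpha=*/1e-3f, /*wd=*/1e-9f);
    std::printf("%g %g %g\n", w[0], w[1], w[2]);
}

Because there are no m/v momenta here, conditioning their allocation on the optimizer type is what yields the memory savings reported later in this PR.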

@graehl
Author

graehl commented May 29, 2025

For adding an SGD optimizer, add a new ggml op like OPT_STEP_SGD. Add a CPU implementation as a fallback for any backend without an implementation. Add a CUDA implementation, since that is (I assume) the backend you intend to use in production. Add a test to tests/test-backend-ops.cpp to assert that the CPU and CUDA backends produce consistent results. Extend ggml-opt.cpp to conditionally use the new SGD optimizer step, and condition the allocation of the optimizer momenta on the optimizer type.

Yes, will do. Should the actual SGD impl be a subsequent pull request (or several, e.g. starting with just the CPU impl), or do you want it all in one PR?

@JohannesGaessler
Collaborator

Either way would be fine with me as long as there are at no point broken or unfinished features on master.

@graehl graehl force-pushed the finelayer branch 2 times, most recently from e752031 to e689af8 Compare May 29, 2025 17:07
Contributor
@matiaslin matiaslin left a comment

Looking forward to the next PR(s).

@graehl
Author

graehl commented May 29, 2025

You should see frivolous clang-format changes (using the project's .clang-format) only on lines changed in the PR (via git-clang-format). If there's something undesirable, we could figure out what in the format config causes it.

@JohannesGaessler
Collaborator

Don't autoformat code en masse unless it's done in a dedicated PR; it makes it unnecessarily difficult to track what was actually changed in a PR.

@JohannesGaessler
Collaborator

Sorry, I didn't read the

only on lines changed in the PR

part.

@github-actions github-actions bot added the build (Compilation issues), testing (Everything test related), and Nvidia GPU (Issues specific to Nvidia GPUs) labels May 30, 2025
@graehl graehl force-pushed the finelayer branch 3 times, most recently from 7534bbf to 48a16bf Compare May 30, 2025 16:57
@graehl
Author

graehl commented May 30, 2025

Hi @WilliamTambellini @JohannesGaessler, I think this is usable now; inviting code nitpicks etc. :)
I'm pretty new to the GitHub interface, honestly, so let me know if this needs to be two separate PRs, one for each commit, or if it's reasonable to just review both commits here (obviously better to merge separately; the first doesn't break any behavior, the second changes the finetune cmdline default learning rate, but that should hurt no one).

@graehl
Author

graehl commented May 30, 2025

The second (actually usable SGD) commit is 48a16bf (also shown above).

Contributor
@WilliamTambellini WilliamTambellini left a comment

This mixes up different projects: the CLI change/renaming and SGD. It needs to be split into 2 PRs.
@slaren ?

@@ -770,7 +814,7 @@ void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
// beta1, beta2 after applying warmup
const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));

const float keep = 1.0f - opt_pars.adamw.alpha * opt_pars.adamw.wd;
Collaborator

Optimizer steps are going to be I/O bound, and optimizing compute is not going to make a meaningful difference for the runtime of the steps; for the runtime of the total program it's completely negligible. So please revert this change; I think the other variant is easier to understand.

Author

I agree that it's not likely to matter, but (1) it's per parameter per epoch (OK, that does seem unimportant now that I think about it further); (2) I'm not confident the CUDA compiler optimizes this and was hoping to learn more; it seems possible that without this we're repeatedly loading two floats instead of one; and mostly (3) this exactly follows the precedent established for beta1h and beta2h, which are stored in the tensor just as I stored this quantity.

Anyway, I'm totally willing; just curious what you think about the existing practice of saving beta1h and beta2h in light of the opinion that we're not compute bound.

Author

I checked it out; it doesn't seem to change runtime noticeably, as you predicted.

Collaborator

My biggest concern with the code is the amount of effort needed to maintain it, particularly when it comes to debugging and asserting that the code on master works correctly. It is quite likely that I will at some point be in a situation where a user reports bad training results and I will not know whether that is due to a bug in ggml or due to bad hyperparameters or something similar. So it is very important to me that the data layout is consistent across multiple levels.

The correct way to implement the micro-optimization of pre-computing a parameter derived from the human-interpretable parameters is as follows:

  1. Pass the human-interpretable parameters to ggml_opt_step_adamw / ggml_opt_step_sgd.
  2. In the CUDA host code, pre-compute some derived parameters from the human-interpretable parameters.
  3. Change the CUDA device code to accept the derived parameters instead.

The way CUDA works is that the CPU schedules the GPU kernels in a CUDA stream and then waits for said stream to finish all kernels. Scheduling the kernels is of course much faster and it doesn't matter how fast you are as long as you are fast enough to keep the GPU busy. So adding a bit of overhead to the scheduling has essentially no impact on the runtime of a CUDA program even if you do it once per CUDA kernel launch instead of once per epoch.
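To illustrate the pattern with a hedged sketch (the kernel, wrapper, and launch configuration below are hypothetical, not the actual ggml-cuda code): the wrapper receives the human-readable alpha and wd, derives keep once on the host per launch, and the device code consumes only the derived value.

// Hypothetical CUDA sketch of the suggested split: human-readable parameters at the
// interface, derived parameters computed in host code at kernel-launch time.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

__global__ void opt_step_sgd_kernel(float * w, const float * g, const int64_t n,
                                    const float alpha, const float keep) {
    const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    w[i] = w[i] * keep - alpha * g[i];
}

static void opt_step_sgd_cuda(float * w, const float * g, const int64_t n,
                              const float alpha, const float wd, cudaStream_t stream) {
    const float keep  = 1.0f - alpha * wd; // derived once per launch; negligible next to the kernel
    const int   block = 256;
    const int   grid  = (int) ((n + block - 1) / block);
    opt_step_sgd_kernel<<<grid, block, 0, stream>>>(w, g, n, alpha, keep);
}

int main() {
    const int64_t n     = 3;
    float         h_w[] = { 1.0f, -2.0f, 0.5f };
    const float   h_g[] = { 0.1f,  0.2f, -0.3f };
    float * d_w = nullptr;
    float * d_g = nullptr;
    cudaMalloc(&d_w, n * sizeof(float));
    cudaMalloc(&d_g, n * sizeof(float));
    cudaMemcpy(d_w, h_w, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_g, h_g, n * sizeof(float), cudaMemcpyHostToDevice);
    opt_step_sgd_cuda(d_w, d_g, n, /*alpha=*/1e-3f, /*wd=*/1e-9f, /*stream=*/0);
    cudaMemcpy(h_w, d_w, n * sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("%g %g %g\n", h_w[0], h_w[1], h_w[2]);
    cudaFree(d_w);
    cudaFree(d_g);
}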

Author

Thanks for explaining all that; the bottom line for me is that you were right, and the micro-optimization has no visible benefit in this case.

@github-actions github-actions bot added the Vulkan (Issues specific to the Vulkan backend) label Jul 3, 2025
@graehl
Author

graehl commented Jul 3, 2025

Fixed: another unreachable break, Windows strcasecmp, and skipping the SGD test-opt case (keeping ADAMW) for Vulkan. Should (nearly) pass CI now.

@lexasub
Contributor

lexasub commented Jul 3, 2025

@graehl, the Windows builds need a fix :)

@lexasub
Contributor

lexasub commented Jul 3, 2025

D:\a\llama.cpp\llama.cpp\common\common.cpp(1571,39): error: dllimport cannot be applied to non-inline function definition
1571 | GGML_API enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {

@lexasub
Contributor

lexasub commented Jul 3, 2025

@graehl, you could create a new MR (from another branch) that won't break the Windows build; then the MR will stay small (if fixing Windows is very hard).
Also, I created an example of loading data for the original finetune.cpp from Parquet: #14522

@graehl
Author

graehl commented Jul 3, 2025

Fixed the Windows DLL GGML_API build.
Fixed test-backend-ops (avoiding SGD on Vulkan; still a TODO).

@lexasub
Contributor

lexasub commented Jul 3, 2025

(screenshot) graehl:finelayer rebased on master (finetune.cpp)

@graehl
Author

graehl commented Jul 4, 2025

(screenshot) graehl:finelayer rebased on master (finetune.cpp)

(I reverted the unintentional diffs shown)

@lexasub
Contributor

lexasub commented Jul 8, 2025

@graehl, the build fails in some Windows cases.

@graehl
Author

graehl commented Jul 8, 2025

@graehl, the build fails in some Windows cases.

I'll see if I can figure it out

@graehl
Author

graehl commented Jul 8, 2025

It seems test-opt can't link on Windows:

test-opt.obj : error LNK2019: unresolved external symbol ggml_backend_is_cpu referenced in function main [D:\a\llama.cpp\llama.cpp\build\tests\test-opt.vcxproj]
test-opt.obj : error LNK2019: unresolved external symbol ggml_backend_cpu_set_n_threads referenced in function main [D:\a\llama.cpp\llama.cpp\build\tests\test-opt.vcxproj]
D:\a\llama.cpp\llama.cpp\build\bin\Release\test-opt.exe : fatal error LNK1120: 2 unresolved externals [D:\a\llama.cpp\llama.cpp\build\tests\test-opt.vcxproj]

It might need to be resolved by re-disabling the test on Windows; any suggestion for a different fix?

@graehl
Author

graehl commented Jul 8, 2025

Do we care about the Vulkan test-backend-ops timing out? It seems nothing is actually wrong...
31: FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[1,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]): OK
31: FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[1,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]): OK
31/36 Test #31: test-backend-ops ..................***Timeout 3600.07 sec
test 34

@JohannesGaessler
Collaborator

The Vulkan issue should be fixed with #14574 .

@graehl
Author

graehl commented Jul 9, 2025

I believe the Windows link issue is fixed by removing the code in test-opt that, on Windows, attempts to set the number of threads for the CPU backend. Can we get a CI/QA run to confirm?

common/common.h Outdated
struct lr_opt lr;
enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
float val_split = 0.05f; // fraction of the data used for the validation set
std::string opt_save_model_to = "finetuned-model.gguf";
Collaborator

Suggested change
std::string opt_save_model_to = "finetuned-model.gguf";

I think you forgot to remove this?

Comment on lines 338 to 339
if (0)
GGML_LOG_DEBUG("%s static=%d accumulate=%d opt_period=%d optimizer=%d\n", __func__, (int32_t)opt_ctx->static_graphs, (int32_t)accumulate, (int32_t)opt_ctx->opt_period, (int32_t)optimizer);
Collaborator

Suggested change
if (0)
GGML_LOG_DEBUG("%s static=%d accumulate=%d opt_period=%d optimizer=%d\n", __func__, (int32_t)opt_ctx->static_graphs, (int32_t)accumulate, (int32_t)opt_ctx->opt_period, (int32_t)optimizer);

Forgot to remove this?

GGML_ASSERT(result->opt_period >= 1);

result->static_graphs = result->ctx_compute;
GGML_LOG_DEBUG("%s opt_period=%d\n", __func__, (int32_t)result->opt_period);
Collaborator

Suggested change
GGML_LOG_DEBUG("%s opt_period=%d\n", __func__, (int32_t)result->opt_period);

Forgot to remove this?

GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
float * sgd = ggml_get_data_f32(opt_ctx->adamw_params);
sgd[1] = 1. - (sgd[0] = opt_pars.sgd.alpha) * opt_pars.sgd.wd;
Collaborator

This is inconsistent with the docstring in ggml.h. As I outlined before for AdamW, the interface in ggml.h should be using the human-readable parameters. Please simply pass alpha and wd here. A derived parameter keep should be calculated in the backend-specific implementations for OPT_STEP_SGD.
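A sketch of what that would look like, following the names in the diff above (illustrative only, not the final code):

// in ggml-opt.cpp: store only the human-readable parameters
float * sgd = ggml_get_data_f32(opt_ctx->adamw_params);
sgd[0] = opt_pars.sgd.alpha;
sgd[1] = opt_pars.sgd.wd;

// in each backend implementation of OPT_STEP_SGD, derive the combined factor locally:
// const float keep = 1.0f - alpha * wd;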

Author

This is fine by me, but I'm holding off on it for now.

Comment on lines 10548 to 10645
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return true;
Collaborator

Suggested change
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return true;
case GGML_OP_OPT_STEP_ADAMW:
return true;
case GGML_OP_OPT_STEP_SGD:
return false;

There is no working Vulkan implementation for OPT_STEP_SGD so this function should return false, the CPU backend will then be used as a fallback. It is not necessary to make any further changes for the Vulkan backend.

Comment on lines 5592 to 5597
char const* name = ggml_backend_name(backend);
bool const vulkan = strstr(name, "ulkan");
bool const sgd = !vulkan;

if (mode == MODE_TEST) {
auto test_cases = make_test_cases_eval();
auto test_cases = make_test_cases_eval(sgd);
Collaborator

You can remove this logic if the Vulkan backend simply returns that OPT_STEP_SGD is unsupported, the corresponding test will then simply be skipped.

helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass);
const bool subtest_ok = weights == -ndata * .5;
TEST_LOG("%s: ndata=%d weights=%f\n", __func__, (int) ndata, (double) weights);
assert(subtest_ok);
Collaborator

Suggested change
assert(subtest_ok);

Forgot to remove this?

@@ -417,75 +450,76 @@ static std::pair<int, int> test_forward_backward(
double loss_unc;
ggml_opt_result_loss(cd.result, &loss, &loss_unc);
subtest_ok = subtest_ok && loss == 18.0 && (shuffle || loss_unc == 0.0);
assert(subtest_ok);
Collaborator

Suggested change
assert(subtest_ok);


double accuracy;
double accuracy_unc;
ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
assert(subtest_ok);
Collaborator

Suggested change
assert(subtest_ok);

Comment on lines 931 to 933
if (optim == GGML_OPT_OPTIMIZER_TYPE_SGD && !strcmp(devname, "Vulkan0"))
// TODO
continue;
Collaborator

Suggested change
if (optim == GGML_OPT_OPTIMIZER_TYPE_SGD && !strcmp(devname, "Vulkan0"))
// TODO
continue;

@graehl
Author

graehl commented Jul 14, 2025

The Mac CI failure seems unrelated (a timeout; just increase it above 900 s on this platform?): https://github.com/ggml-org/llama.cpp/actions/runs/16152514813/job/45670340308?pr=13873 (27 - test-thread-safety (Timeout) main)
27: 14.14.041.515 I Model 1/2, Context 2/4: The meaning of life is on a cold day. Jane and her mom are going to the store. Jane likes the store, but she has a problem.
27: "Mom, why are you so upset?" Jane's mom asked her.
27: "I'm upset because I'm out of sugar," Jane replied.
27: "Oh Jane, why don't we make some sugar? We will do it together," her mom suggested.
27: Jane was excited. She was ready to make sugar with her mom. But, before she started, her mom said, "Jane, why don't you use a jar of sugar
27:
27/36 Test #27: test-thread-safety ................***Timeout 900.04 sec

@graehl
Author

graehl commented Jul 15, 2025

I think we should skip the SGD part of test-opt on Vulkan, as I had it (I do not know how to fix this; I don't have a local Vulkan backend machine):
2025-07-14T21:47:15.4378307Z 29: /home/runner/work/llama.cpp/llama.cpp/ggml/src/ggml-backend.cpp:750: pre-allocated tensor (SGD step for weights) in a buffer (Vulkan0) that cannot run the operation (GLU) ... 2025-07-14T21:47:15.4383508Z 29: /home/runner/work/llama.cpp/llama.cpp/build/bin/test-opt(+0x8a09)[0x559e7c319a09] 2025-07-14T21:47:15.4383987Z 29: /home/runner/work/llama.cpp/llama.cpp/build/bin/test-opt(+0x498e)[0x559e7c31598e] ... 2025-07-14T21:47:16.2324353Z 29/36 Test #29: test-opt ..........................Subprocess aborted***Exception: 1.41 sec

@graehl graehl force-pushed the finelayer branch 2 times, most recently from 56ca1c4 to 61cc635 Compare July 15, 2025 17:38
@graehl
Author

graehl commented Jul 15, 2025

I took some guesses about disabling allocation for Vulkan/SGD; perhaps they will work if you run CI.

@graehl
Author

graehl commented Jul 15, 2025

Also, presumably you're aware of this, but the 'build-linux-cross' Vulkan CI machines are misconfigured (at least intermittently). (screenshot)

add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml; avoids allocating the
m, v tensors.

support finetune.cpp arg -opt SGD (or sgd); default adamw as before.

llama 3.2-1b-F32 result: observed 11 GB GPU RAM (41 sec/epoch)
when using SGD instead of 19 GB (55 sec/epoch) using adamw
(wikipedia 100-line finetune).

(
using the same GPU memory, adamw can only manage a 512
batch/context before OOM, reaching:
train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
val:   [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

SGD is superior, though it converges more slowly, with a max of 1728
batch/context before OOM (especially note the better validation perf):
train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
val:   [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00
)

note: when finetuning long enough (or w/ enough -lr),
validation accuracy *eventually* drops ('catastrophic forgetting')

the -lr-half (half-life) option is useful for SGD to avoid oscillation or
very slow underdamped learning (it makes setting -lr more forgiving).
the terminal -lr is for now set by -lr-halvings, i.e. if you want at most
1/8 the initial -lr you set -lr-halvings 3.

note: objective loss may not be directly comparable between adamw and sgd;
check perplexity or accuracy, or consider relative improvements,
when judging convergence.

new finetune args -wd 1e-9 to enable weight decay in sgd or adamw,
and max -epochs N (default 2 as before)

caching (1 - wd*alpha) in the 'adamw' opt struct gave
no noticeable perf benefit, so it is disabled (still done
for the new SGD though).

since optimizer memory is pre-allocated, ggml_opt_get_optimizer_params
could probably switch between SGD and AdamW with each epoch,
but would need to use adamw for the first (unconfirmed; no cmdline arg
to set such a policy yet).

test-opt checks adamw as before and now also sgd (except for a few tests
disabled for sgd only; these probably just need logging the values and adding
alternate reference values); tolerance on the 'regression'
test is broader for sgd (so we don't need many more epochs).
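For intuition only, a small standalone sketch of the kind of schedule the -lr-half / -lr-halvings description above suggests; the function name and exact formula are assumptions, and the PR's implementation may differ.

// Assumed shape of the schedule: exponential decay with a half-life in epochs,
// floored at lr0 / 2^lr_halvings (e.g. -lr-halvings 3 -> lr0/8).
#include <cmath>
#include <cstdio>

static float lr_at_epoch(float lr0, float half_life_epochs, int lr_halvings, int epoch) {
    const float lr_min = lr0 / std::pow(2.0f, (float) lr_halvings);
    const float lr     = lr0 * std::pow(0.5f, (float) epoch / half_life_epochs);
    return lr > lr_min ? lr : lr_min;
}

int main() {
    for (int epoch = 0; epoch < 8; ++epoch) {
        std::printf("epoch %d: lr %g\n", epoch, lr_at_epoch(1e-3f, 2.0f, 3, epoch));
    }
}

With lr0 = 1e-3, a half-life of 2 epochs, and -lr-halvings 3, the rate decays each epoch and bottoms out at 1.25e-4, i.e. 1/8 of the initial value.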