23
23
#include < algorithm>
24
24
#include < array>
25
25
#include < cfloat>
26
+ #include < cinttypes>
26
27
#include < cstdint>
28
+ #include < cstdio>
29
+ #include < cstdlib>
27
30
#include < cstring>
28
- #include < cinttypes >
31
+ #include < future >
29
32
#include < memory>
30
33
#include < random>
31
- #include < stdio.h>
32
- #include < stdlib.h>
34
+ #include < regex>
33
35
#include < string>
34
36
#include < thread>
35
- #include < future>
36
37
#include < vector>
37
38
38
39
static void init_tensor_uniform (ggml_tensor * tensor, float min = -1 .0f , float max = 1 .0f ) {
@@ -4382,9 +4383,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
4382
4383
return test_cases;
4383
4384
}
4384
4385
4385
- static bool test_backend (ggml_backend_t backend, test_mode mode, const char * op_name) {
4386
+ static bool test_backend (ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter) {
4387
+ auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
4388
+ if (params_filter == nullptr ) {
4389
+ return ;
4390
+ }
4391
+
4392
+ std::regex params_filter_regex (params_filter);
4393
+
4394
+ for (auto it = test_cases.begin (); it != test_cases.end ();) {
4395
+ if (!std::regex_search ((*it)->vars (), params_filter_regex)) {
4396
+ it = test_cases.erase (it);
4397
+ continue ;
4398
+ }
4399
+
4400
+ it++;
4401
+ }
4402
+ };
4403
+
4386
4404
if (mode == MODE_TEST) {
4387
4405
auto test_cases = make_test_cases_eval ();
4406
+ filter_test_cases (test_cases, params_filter);
4388
4407
ggml_backend_t backend_cpu = ggml_backend_init_by_type (GGML_BACKEND_DEVICE_TYPE_CPU, NULL );
4389
4408
if (backend_cpu == NULL ) {
4390
4409
printf (" Failed to initialize CPU backend\n " );
@@ -4406,6 +4425,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
4406
4425
4407
4426
if (mode == MODE_GRAD) {
4408
4427
auto test_cases = make_test_cases_eval ();
4428
+ filter_test_cases (test_cases, params_filter);
4409
4429
size_t n_ok = 0 ;
4410
4430
for (auto & test : test_cases) {
4411
4431
if (test->eval_grad (backend, op_name)) {
@@ -4419,6 +4439,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
4419
4439
4420
4440
if (mode == MODE_PERF) {
4421
4441
auto test_cases = make_test_cases_perf ();
4442
+ filter_test_cases (test_cases, params_filter);
4422
4443
for (auto & test : test_cases) {
4423
4444
test->eval_perf (backend, op_name);
4424
4445
}
@@ -4429,7 +4450,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
4429
4450
}
4430
4451
4431
4452
static void usage (char ** argv) {
4432
- printf (" Usage: %s [mode] [-o op ] [-b backend]\n " , argv[0 ]);
4453
+ printf (" Usage: %s [mode] [-o <op> ] [-b < backend>] [-p <params regex> ]\n " , argv[0 ]);
4433
4454
printf (" valid modes:\n " );
4434
4455
printf (" - test (default, compare with CPU backend for correctness)\n " );
4435
4456
printf (" - grad (compare gradients from backpropagation with method of finite differences)\n " );
@@ -4439,8 +4460,9 @@ static void usage(char ** argv) {
4439
4460
4440
4461
int main (int argc, char ** argv) {
4441
4462
test_mode mode = MODE_TEST;
4442
- const char * op_name_filter = NULL ;
4443
- const char * backend_filter = NULL ;
4463
+ const char * op_name_filter = nullptr ;
4464
+ const char * backend_filter = nullptr ;
4465
+ const char * params_filter = nullptr ;
4444
4466
4445
4467
for (int i = 1 ; i < argc; i++) {
4446
4468
if (strcmp (argv[i], " test" ) == 0 ) {
@@ -4463,6 +4485,13 @@ int main(int argc, char ** argv) {
4463
4485
usage (argv);
4464
4486
return 1 ;
4465
4487
}
4488
+ } else if (strcmp (argv[i], " -p" ) == 0 ) {
4489
+ if (i + 1 < argc) {
4490
+ params_filter = argv[++i];
4491
+ } else {
4492
+ usage (argv);
4493
+ return 1 ;
4494
+ }
4466
4495
} else {
4467
4496
usage (argv);
4468
4497
return 1 ;
@@ -4509,7 +4538,7 @@ int main(int argc, char ** argv) {
4509
4538
printf (" Device memory: %zu MB (%zu MB free)\n " , total / 1024 / 1024 , free / 1024 / 1024 );
4510
4539
printf (" \n " );
4511
4540
4512
- bool ok = test_backend (backend, mode, op_name_filter);
4541
+ bool ok = test_backend (backend, mode, op_name_filter, params_filter );
4513
4542
4514
4543
printf (" Backend %s: " , ggml_backend_name (backend));
4515
4544
if (ok) {
0 commit comments