Skip to content

Commit d5c63cd

Browse files
authored
test-backend-ops : add option -p to filter by op params (#12155)
1 parent 9660ffe commit d5c63cd

File tree

1 file changed

+38
-9
lines changed

1 file changed

+38
-9
lines changed

tests/test-backend-ops.cpp

+38-9
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,17 @@
2323
#include <algorithm>
2424
#include <array>
2525
#include <cfloat>
26+
#include <cinttypes>
2627
#include <cstdint>
28+
#include <cstdio>
29+
#include <cstdlib>
2730
#include <cstring>
28-
#include <cinttypes>
31+
#include <future>
2932
#include <memory>
3033
#include <random>
31-
#include <stdio.h>
32-
#include <stdlib.h>
34+
#include <regex>
3335
#include <string>
3436
#include <thread>
35-
#include <future>
3637
#include <vector>
3738

3839
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
@@ -4382,9 +4383,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
43824383
return test_cases;
43834384
}
43844385

4385-
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
4386+
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter) {
4387+
auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
4388+
if (params_filter == nullptr) {
4389+
return;
4390+
}
4391+
4392+
std::regex params_filter_regex(params_filter);
4393+
4394+
for (auto it = test_cases.begin(); it != test_cases.end();) {
4395+
if (!std::regex_search((*it)->vars(), params_filter_regex)) {
4396+
it = test_cases.erase(it);
4397+
continue;
4398+
}
4399+
4400+
it++;
4401+
}
4402+
};
4403+
43864404
if (mode == MODE_TEST) {
43874405
auto test_cases = make_test_cases_eval();
4406+
filter_test_cases(test_cases, params_filter);
43884407
ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
43894408
if (backend_cpu == NULL) {
43904409
printf(" Failed to initialize CPU backend\n");
@@ -4406,6 +4425,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
44064425

44074426
if (mode == MODE_GRAD) {
44084427
auto test_cases = make_test_cases_eval();
4428+
filter_test_cases(test_cases, params_filter);
44094429
size_t n_ok = 0;
44104430
for (auto & test : test_cases) {
44114431
if (test->eval_grad(backend, op_name)) {
@@ -4419,6 +4439,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
44194439

44204440
if (mode == MODE_PERF) {
44214441
auto test_cases = make_test_cases_perf();
4442+
filter_test_cases(test_cases, params_filter);
44224443
for (auto & test : test_cases) {
44234444
test->eval_perf(backend, op_name);
44244445
}
@@ -4429,7 +4450,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
44294450
}
44304451

44314452
static void usage(char ** argv) {
4432-
printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
4453+
printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>]\n", argv[0]);
44334454
printf(" valid modes:\n");
44344455
printf(" - test (default, compare with CPU backend for correctness)\n");
44354456
printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
@@ -4439,8 +4460,9 @@ static void usage(char ** argv) {
44394460

44404461
int main(int argc, char ** argv) {
44414462
test_mode mode = MODE_TEST;
4442-
const char * op_name_filter = NULL;
4443-
const char * backend_filter = NULL;
4463+
const char * op_name_filter = nullptr;
4464+
const char * backend_filter = nullptr;
4465+
const char * params_filter = nullptr;
44444466

44454467
for (int i = 1; i < argc; i++) {
44464468
if (strcmp(argv[i], "test") == 0) {
@@ -4463,6 +4485,13 @@ int main(int argc, char ** argv) {
44634485
usage(argv);
44644486
return 1;
44654487
}
4488+
} else if (strcmp(argv[i], "-p") == 0) {
4489+
if (i + 1 < argc) {
4490+
params_filter = argv[++i];
4491+
} else {
4492+
usage(argv);
4493+
return 1;
4494+
}
44664495
} else {
44674496
usage(argv);
44684497
return 1;
@@ -4509,7 +4538,7 @@ int main(int argc, char ** argv) {
45094538
printf(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
45104539
printf("\n");
45114540

4512-
bool ok = test_backend(backend, mode, op_name_filter);
4541+
bool ok = test_backend(backend, mode, op_name_filter, params_filter);
45134542

45144543
printf(" Backend %s: ", ggml_backend_name(backend));
45154544
if (ok) {

0 commit comments

Comments
 (0)