Skip to content

Commit e61bb0a

Browse files
committed
[feat] test loss func & assert close
1 parent 265d430 commit e61bb0a

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import copy
2+
3+
import torch
4+
from coati.distributed.loss import PolicyLoss
5+
from torch.testing import assert_close
6+
7+
from colossalai.testing import parameterize
8+
from colossalai.utils import set_seed
9+
10+
11+
@parameterize(
12+
"test_config",
13+
[
14+
{
15+
"precision": torch.bfloat16,
16+
"device": "npu",
17+
},
18+
],
19+
)
20+
def run_policy_loss_fn(test_config):
21+
dtype = test_config["precision"]
22+
device = test_config["device"]
23+
set_seed(42)
24+
policy_loss_fn = PolicyLoss()
25+
26+
############
27+
# init npu tensor
28+
############
29+
action_log_probs = torch.rand(8, 2048, dtype=dtype, device=device) # float [8, 2048]
30+
old_action_log_probs = torch.rand(8, 2048, dtype=dtype, device=device) # float [8, 2048]
31+
advantages = torch.rand(8, dtype=dtype, device=device) # float [8]
32+
per_token_kl = torch.rand(8, 2048, dtype=dtype, device=device) # float [8, 2048]
33+
action_mask = torch.randint(
34+
low=0, high=2, size=(8, 2048), dtype=torch.int32, device=device
35+
) # torch.int32 [8, 2048] in range(0,1)
36+
37+
loss, skip_update, _ = policy_loss_fn(
38+
action_log_probs,
39+
old_action_log_probs,
40+
advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
41+
per_token_kl,
42+
action_mask,
43+
)
44+
45+
############
46+
# init cpu tensor
47+
############
48+
action_log_probs_cpu = copy.deepcopy(action_log_probs.cpu())
49+
old_action_log_probs_cpu = copy.deepcopy(old_action_log_probs.cpu())
50+
advantages_cpu = copy.deepcopy(advantages.cpu())
51+
per_token_kl_cpu = copy.deepcopy(per_token_kl.cpu())
52+
action_mask_cpu = copy.deepcopy(action_mask.cpu())
53+
54+
loss_cpu, skip_update_cpu, _ = policy_loss_fn(
55+
action_log_probs_cpu,
56+
old_action_log_probs_cpu,
57+
advantages_cpu.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
58+
per_token_kl_cpu,
59+
action_mask_cpu,
60+
)
61+
62+
# assert close
63+
assert_close(
64+
loss.to("cpu"),
65+
loss_cpu,
66+
rtol=5e-4,
67+
atol=5e-4,
68+
# msg=f"NPU/CPU {test_config['precision']} not close"
69+
)
70+
71+
72+
def test_loss_func():
73+
run_policy_loss_fn()
74+
75+
76+
if __name__ == "__main__":
77+
test_loss_func()

0 commit comments

Comments
 (0)