-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Operator] Init NLL_LOSS #269
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
benchmark/test_reduction_perf.py
Outdated
@@ -135,6 +141,13 @@ def cumsum_input_fn(shape, cur_dtype, device): | |||
FLOAT_DTYPES + INT_DTYPES, | |||
marks=pytest.mark.cumsum, | |||
), | |||
pytest.param( | |||
"nll_loss", | |||
torch.nn.NLLLoss, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NLLLoss is a class. Can we use it as the reference function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed. I've updated it to torch.nn.functional.nll_loss
.
BLOCK_N: tl.constexpr, | ||
): | ||
pid_n = tl.program_id(0) | ||
offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check offset_n
overflow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I didn't understand this suggestion. Shouldn't the subsequent mask generation be sufficient to handle potential overflow? Or are you suggesting we check if offsets_n
could exceed Triton's maximum representable number? If so, I think many operators will need to incorporate this check too.
): | ||
pid_n = tl.program_id(0) | ||
pid_d = tl.program_id(1) | ||
offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overflow check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
): | ||
pid_n = tl.program_id(0) | ||
pid_c = tl.program_id(1) | ||
offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overflow check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
src/flag_gems/ops/nllloss.py
Outdated
|
||
if weight is None: | ||
weight = torch.ones( | ||
[ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use tuple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
def forward(ctx, inp, target, weight, reduction, ignore_index): | ||
logging.debug("GEMS NLLLoss FWD") | ||
shape = list(inp.shape) | ||
dim = inp.ndim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The input shape/layout appears to be pretty complex as shown in the pytorch doc. Shall we add some documentation here to help clarify the inputs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've tried to add some annotations here about nllloss's parameters.
However, I think users who will use this function should also be quite clear about its principles; the part that looks complicated is just a high-dimensional input 😄
): | ||
pid_n = tl.program_id(0) | ||
pid_d = tl.program_id(1) | ||
offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overflow check.
src/flag_gems/ops/nllloss.py
Outdated
tl.store(inp_grad_ptrs, inp_grad.to(tl.float32), mask=(inp_mask & ignore_mask)) | ||
|
||
|
||
class NLLLoss(torch.autograd.Function): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is intended be used as substitute for nll_loss
whereas NLLLoss is already taken as the nn module name. We should avoid the name confusion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed. I've upated the class name.
if reduction == 0: | ||
res = out.to(inp.dtype) | ||
elif reduction == 1: | ||
ctx.total_weight = sum(w_tgt).item() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we also add dim=
args to avoid confusion?
A basic implementation of
NLL_LOSS
has been pushed.Based on the performance testing results summarized earlier, we believe that using the
gather
operation would lead to a more efficient implementation (by observing the output results of latency, it seems this is also how torch does it), and we will push forward with this optimization.