[Operator] Init NLL_LOSS #269

Open · wants to merge 8 commits into master

Conversation

GwokHiujin (Collaborator)

A basic implementation of NLL_LOSS has been pushed.

Based on the performance test results summarized earlier, we believe a gather-based implementation would be more efficient (judging by the latency results, this also appears to be how torch implements it), and we will push forward with that optimization.
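For context, here is a minimal sketch of the gather-based formulation mentioned above, written in plain PyTorch rather than Triton; the function name and defaults are illustrative, not the PR's code:

import torch

def nll_loss_via_gather(log_probs, target, weight=None, reduction="mean", ignore_index=-100):
    # log_probs: (N, C) log-probabilities; target: (N,) class indices.
    n, c = log_probs.shape
    w = weight if weight is not None else torch.ones(c, dtype=log_probs.dtype, device=log_probs.device)

    safe_target = target.clamp(min=0)  # keep gather indices in range when ignore_index is negative
    picked = log_probs.gather(1, safe_target.unsqueeze(1)).squeeze(1)
    w_tgt = w[safe_target]

    valid = (target != ignore_index).to(log_probs.dtype)  # ignored targets contribute nothing
    loss = -picked * w_tgt * valid

    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    return loss.sum() / (w_tgt * valid).sum()  # "mean": weighted average over valid samples

The single gather replaces any per-class indexing loop, which is presumably where the latency win comes from.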

tongxin self-assigned this Nov 10, 2024
tongxin previously approved these changes Nov 15, 2024
tongxin (Contributor) left a comment

LGTM

@@ -135,6 +141,13 @@ def cumsum_input_fn(shape, cur_dtype, device):
        FLOAT_DTYPES + INT_DTYPES,
        marks=pytest.mark.cumsum,
    ),
    pytest.param(
        "nll_loss",
        torch.nn.NLLLoss,
Contributor

NLLLoss is a class. Can we use it as the reference function?

Collaborator (Author)

Indeed. I've updated it to torch.nn.functional.nll_loss.
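For illustration, the updated benchmark entry might look roughly like the fragment below; the dtype list and the pytest mark name are assumptions that mirror the cumsum entry above:

pytest.param(
    "nll_loss",
    torch.nn.functional.nll_loss,  # a plain function, so it can be called directly as the reference op
    FLOAT_DTYPES,
    marks=pytest.mark.nll_loss,
),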

    BLOCK_N: tl.constexpr,
):
    pid_n = tl.program_id(0)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
Contributor

Check offsets_n for overflow.

Collaborator (Author)

Sorry, I didn't understand this suggestion. Shouldn't the subsequent mask generation be sufficient to handle potential overflow? Or are you suggesting we check whether offsets_n could exceed Triton's maximum representable integer? If so, I think many other operators would need to incorporate this check as well.
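For reference, one common way to address this kind of concern in Triton kernels is to do the offset arithmetic in int64, so that pid_n * BLOCK_N cannot wrap around int32 on very large inputs while the mask still guards out-of-range lanes. A minimal sketch (illustrative kernel, not the PR's code):

import triton
import triton.language as tl

@triton.jit
def copy_block_kernel(inp_ptr, out_ptr, N, BLOCK_N: tl.constexpr):
    pid_n = tl.program_id(0)
    # Promote before the multiply: with int32 arithmetic, pid_n * BLOCK_N can
    # overflow once the element count approaches 2**31, even though the mask is correct.
    offsets_n = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n = offsets_n < N
    x = tl.load(inp_ptr + offsets_n, mask=mask_n, other=0.0)
    tl.store(out_ptr + offsets_n, x, mask=mask_n)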

):
    pid_n = tl.program_id(0)
    pid_d = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
Contributor

Overflow check.

Collaborator (Author)

ditto

):
    pid_n = tl.program_id(0)
    pid_c = tl.program_id(1)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
Contributor

Overflow check.

Collaborator (Author)

ditto


    if weight is None:
        weight = torch.ones(
            [
Contributor

Use a tuple.

Collaborator (Author)

Done.
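For reference, the change amounts to swapping the shape literal; the exact shape expression and keyword arguments used in the PR aren't shown here, so C below just stands for the class dimension:

# before: list literal
weight = torch.ones([C], dtype=inp.dtype, device=inp.device)
# after: tuple, as suggested
weight = torch.ones((C,), dtype=inp.dtype, device=inp.device)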

def forward(ctx, inp, target, weight, reduction, ignore_index):
    logging.debug("GEMS NLLLoss FWD")
    shape = list(inp.shape)
    dim = inp.ndim
Contributor

The input shape/layout appears to be pretty complex, as shown in the PyTorch docs. Shall we add some documentation here to help clarify the inputs?

Collaborator (Author)

I've added some annotations here describing nll_loss's parameters.

That said, I think users of this function should already be familiar with how it works; the part that looks complicated is just the high-dimensional input 😄
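A sketch of the kind of annotation described above; the shapes follow the torch.nn.functional.nll_loss documentation, and the integer reduction encoding matches the reduction == 0 / reduction == 1 branches visible later in this diff:

def forward(ctx, inp, target, weight, reduction, ignore_index):
    """Negative log likelihood loss (shapes follow torch.nn.functional.nll_loss).

    inp:          (N, C) or (N, C, d1, ..., dk) log-probabilities
    target:       (N,) or (N, d1, ..., dk) class indices in [0, C)
    weight:       (C,) per-class rescaling weights, or None
    reduction:    0 = none, 1 = mean, 2 = sum (ATen-style integer encoding)
    ignore_index: targets equal to this value contribute neither loss nor gradient
    """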

):
    pid_n = tl.program_id(0)
    pid_d = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
Contributor

Overflow check.

tl.store(inp_grad_ptrs, inp_grad.to(tl.float32), mask=(inp_mask & ignore_mask))


class NLLLoss(torch.autograd.Function):
Contributor

This function is intended to be used as a substitute for nll_loss, whereas NLLLoss is already taken as the nn module name. We should avoid the name confusion.

Collaborator (Author)

Indeed. I've updated the class name.
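The new name isn't shown in this excerpt; a hypothetical layout that keeps the autograd.Function separate from the torch.nn.NLLLoss module name could look like:

class NllLossFunction(torch.autograd.Function):  # hypothetical name, not necessarily the PR's choice
    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index):
        ...

    @staticmethod
    def backward(ctx, out_grad):
        ...

def nll_loss(inp, target, weight=None, reduction=1, ignore_index=-100):
    # Functional entry point, so callers never reference the Function subclass directly.
    return NllLossFunction.apply(inp, target, weight, reduction, ignore_index)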

if reduction == 0:
    res = out.to(inp.dtype)
elif reduction == 1:
    ctx.total_weight = sum(w_tgt).item()
Contributor

Shall we also add the dim= argument to avoid confusion?
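For illustration, one reading of the suggestion (assuming w_tgt is a 1-D tensor of per-sample target weights): the builtin sum iterates over dim 0 implicitly, while an explicit torch reduction spells out the axis.

# implicit: Python's builtin sum iterates over the first dimension of w_tgt
ctx.total_weight = sum(w_tgt).item()
# explicit, per the suggestion
ctx.total_weight = torch.sum(w_tgt, dim=0).item()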
