Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions inference/models/mixtral.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
Tensor aggregate_inputs[4 + mixtral_config.num_local_experts] = {nullptr};
for (int expert_idx = 0; expert_idx < mixtral_config.num_local_experts;
expert_idx++) {
grouped_tokens[expert_idx] = ff_norm; // TODO this is a dirty fix. Restore using group_by!
// grouped_tokens[expert_idx] = ff_norm; // TODO this is a dirty fix. Restore using group_by!
Tensor w1 = ff.dense(grouped_tokens[expert_idx], // (hidden_size, 1, result of calc in groupby)
mixtral_config.intermediate_size,
AC_MODE_NONE,
Expand Down Expand Up @@ -336,7 +336,7 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,

aggregate_inputs[0] = topk_values;
aggregate_inputs[1] = topk_indices;
aggregate_inputs[2] = topk_values;
aggregate_inputs[2] = topk_values; // TODO Causes Legion runtime error!!
aggregate_inputs[3] = gate;
mlp_out = aggregate_inputs[5]; // TODO don't use only one expert
// mlp_out = ff.aggregate(aggregate_inputs,
Expand Down
15 changes: 15 additions & 0 deletions src/ops/aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ OpMeta *Aggregate::init_task(Task const *task,
// ... including some steps with GenericTensorAccessorR
// Should I include?

// Only needed to allocate memory in the kernel
AggregateMeta *m = new AggregateMeta(handle, agg, gpu_mem_allocator);
for (int i = 0; i < 10; i++) { // TODO 10 is a magic number
m->input_type[i] = agg->inputs[i]->data_type;
Expand Down Expand Up @@ -489,6 +490,18 @@ void Aggregate::forward_task(Task const *task,
ctx, task->regions[total_input_cnt].region.get_index_space());


Aggregate::forward_kernel_wrapper(m,
bc,
exp_preds,
acc_gate_assign.ptr(rect_gate_assign),
acc_gate_pred.ptr(rect_gate_pred),
acc_output.ptr(rect_output),
n,
k,
rows,
batch_size,
out_dim);

// TODO One of these three lines causes the mismatch error
// get gate_pred, gate_assign, output
//AccessorRO<float, 3> const acc_gate_pred(regions[0], FID_DATA); // This one alone does cause the problem
Expand Down Expand Up @@ -534,6 +547,7 @@ void Aggregate::forward_task(Task const *task,

// printf("CALLING FOWARD_KERNEL_WRAPPER IN FORWARD_TASK\n");

// From ZJ: we lose the shape of the tensors when we use this approach. The approach in sigmoid silu is recommended
// Aggregate::forward_kernel_wrapper(m,
// exp_preds,
// acc_gate_assign.ptr(rect_gate_assign),
Expand Down Expand Up @@ -604,6 +618,7 @@ void Aggregate::inference_task(Task const *task,

// TODO should we have an inference_kernel wrapper?
Aggregate::forward_kernel_wrapper(m,
bc,
exp_preds,
acc_gate_assign.ptr(rect_gate_assign),
acc_gate_pred.ptr(rect_gate_pred),
Expand Down
2 changes: 2 additions & 0 deletions src/ops/aggregate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ __global__ void agg_backward_kernel(float **exp_preds,

/*static*/
void Aggregate::forward_kernel_wrapper(AggregateMeta const *m,
BatchConfig const *bc,
float **exp_preds,
int const *acc_gate_assign_ptr,
float const *acc_gate_pred_ptr,
Expand Down Expand Up @@ -307,6 +308,7 @@ void Aggregate::backward_kernel_wrapper(AggregateMeta const *m,
}
}

// Only needed if we allocate memory, which is not our case
AggregateMeta::AggregateMeta(FFHandler handler,
Aggregate const *aggr,
MemoryAllocator &gpu_mem_allocator)
Expand Down
4 changes: 3 additions & 1 deletion src/ops/inc_multihead_self_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,8 @@ OpMeta *IncMultiHeadSelfAttention::init_task(
// printf("running IncMultiHeadSelfAttention::init_task\n");
IncMultiHeadSelfAttention const *attn = (IncMultiHeadSelfAttention *)task->args;
FFHandler handle = *((FFHandler const *)task->local_args);

// We call the below to get the shape info, so we can do the assertions
// I also shouldn't care about offloading
GenericTensorAccessorR input =
helperGetGenericTensorAccessorRO(attn->inputs[0]->data_type,
regions[0],
Expand Down Expand Up @@ -745,6 +746,7 @@ bool IncMultiHeadSelfAttention::get_int_parameter(PMParameter para,
}
}

// Just for benchmarking, don't need that
bool IncMultiHeadSelfAttention::measure_operator_cost(
Simulator *sim, MachineView const &mv, CostMetrics &cost_metrics) const {
return false;
Expand Down
3 changes: 2 additions & 1 deletion src/ops/inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m,
int const per_head_size = m->qProjSize;
float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f;
size_t smem_sz;
if (per_head_size == 32) {
if (per_head_size == 32) { // ok to do that
constexpr int THREADS_PER_VALUE_32 = threads_per_value_t<DT, 32>::value;
LAUNCH_ATTENTION_SCORE_KERNEL(
DT, 32, 32, 4, THREADS_PER_VALUE_32, 128, stream);
Expand Down Expand Up @@ -1517,6 +1517,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper(
assert(input.data_type == output.data_type);

if (input.data_type == DT_HALF) {
// Calling input.get_inc_ptr() below would cause a Legion error: type mismatch in get index space domain
Kernels::IncMultiHeadAttention::inference_kernel(
m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream);
} else if (input.data_type == DT_FLOAT) {
Expand Down
2 changes: 1 addition & 1 deletion src/ops/sigmoid_silu_multi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ void SigmoidSiluMulti::inference_kernel_wrapper(
min(CUDA_NUM_THREADS, num_elements),
0,
stream>>>(input1.domain.get_volume(),
input1.get_float_ptr(),
input1.get_float_ptr(), // Ultimately we get pointers,whereas in mixtarl branch we pass pointers to this func.
input2.get_float_ptr(),
output.get_float_ptr());
} else if (m->input_type[0] == DT_HALF) {
Expand Down