Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mosaic TPU] Use vmask pack if possible for mask's bitwidth change and introduce relayout op. #25938

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/_src/tpu_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def get_dump_file_prefix() -> str:
(
"func.func(tpu-relayout-insertion{"
f" sublane-count={sl_cnt} lane-count={l_cnt}"
f" hardware-generation={hardware_generation}"
"})"
),
]
Expand Down
18 changes: 18 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,21 @@ def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]>
let hasVerifier = 1;
}

// Changes the layout of `input` without changing its value or element type
// (operands and results share one type via SameOperandsAndResultType).
// The relayout-insertion pass creates this op and attaches the source and
// destination layouts as in/out layout attributes; apply-vector-layout then
// lowers it (see tpu_relayout_rule).
// NOTE(review): deliberately not marked Pure — presumably so it is not
// folded/CSE'd away before the layout attributes are consumed; confirm.
def TPU_RelayoutOp : TPU_Op<"relayout", [SameOperandsAndResultType]> {
  let arguments = (ins AnyType:$input);
  let results = (outs AnyType:$output);
  let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
}

// Packs two i1 vector masks (`low`, `high`, which must have identical types)
// into a single mask whose storage bitwidth is halved — e.g. two vmsks laid
// out for 32-bit data packed into one vmsk laid out for 16-bit data, as done
// by tpu_relayout_rule in apply-vector-layout.
def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> {
  let arguments = (ins
    VectorOfNonZeroRankOf<[I1]>: $low,
    VectorOfNonZeroRankOf<[I1]>: $high
  );
  let results = (outs VectorOfNonZeroRankOf<[I1]>:$output);
  let assemblyFormat = [{ $low `,` $high `,` attr-dict `:` type($low) `,` type($high) `->` type($output) }];
}

def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
let arguments = (ins
AnyVectorOfNonZeroRank:$source,
Expand Down Expand Up @@ -891,6 +906,9 @@ def RelayoutInsertionPass : Pass<"tpu-relayout-insertion", "::mlir::func::FuncOp
];
let constructor = "::mlir::tpu::createRelayoutInsertionPass()";
let options = [
// If hardware_generation is not set, the default value of -1 will crash on
// runOnOperation.
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
];
Expand Down
1 change: 1 addition & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
const TpuTilingFlags &tpu_tiling_flags = {});

std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
int hardware_generation = -1,
std::array<int64_t, 2> target_shape = {8, 128});

std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
Expand Down
77 changes: 73 additions & 4 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2172,6 +2172,74 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op,
return success();
}

// Lowers tpu.relayout. An identity relayout (in-layout == out-layout) is
// folded away by forwarding the input. The only non-trivial case implemented
// so far is packing an i1 vector mask from 32-bit to 16-bit storage with
// tpu.pack_vmsk: adjacent pairs of input vregs along the minor dimension are
// packed into one output vreg, padding with an all-false mask when the input
// vreg count is odd. All other layout changes fail with "Not implemented".
LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op,
                                const ArrayRef<Layout> layouts_in,
                                const ArrayRef<Layout> layouts_out) {
  TPU_ASSERT_EQ_OP(op.getNumOperands(), 1);
  TPU_ASSERT_EQ_OP(op.getNumResults(), 1);
  TPU_ASSERT_EQ_OP(layouts_in.size(), 1);
  TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
  TPU_ASSERT_OP(layouts_in[0].has_value());
  TPU_ASSERT_OP(layouts_out[0].has_value());
  const auto& in_layout = *layouts_in[0];
  const auto& out_layout = *layouts_out[0];
  // Fixed typo: local was previously spelled `realyout_op`.
  auto relayout_op = cast<tpu::RelayoutOp>(op);
  auto in_bitwidth = in_layout.bitwidth();
  auto out_bitwidth = out_layout.bitwidth();
  auto vty = cast<VectorType>(relayout_op.getType());
  ImplicitLocOpBuilder builder(op.getLoc(), &op);
  if (in_layout == out_layout) {
    // No layout change requested: the op is a no-op, forward the input.
    relayout_op.replaceAllUsesWith(relayout_op.getInput());
    relayout_op.erase();
    return success();
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      xla::Array<Value> vals,
      disassemble(builder, in_layout,
                  cast<TypedValue<VectorType>>(relayout_op.getInput()),
                  ctx.target_shape,
                  /*use_implicit_shape=*/true));
  // Packing vector masks from 32-bit to 16-bit. Requires native source
  // tiling and identical tiling/offsets/implicit-dim on both layouts.
  if (vty.getElementType() == builder.getI1Type() && in_bitwidth == 32 &&
      out_bitwidth == 16 &&
      in_layout.tiling()[0] == in_layout.packing() * ctx.target_shape[0] &&
      in_layout.tiling()[1] == ctx.target_shape[1] &&
      in_layout.tiling() == out_layout.tiling() &&
      in_layout.offsets() == out_layout.offsets() &&
      in_layout.implicit_dim() == out_layout.implicit_dim()) {
    // Halving the bitwidth halves the vreg count along the minor dimension.
    std::vector<int64_t> vmsks_shape(vals.dimensions().begin(),
                                     vals.dimensions().end());
    *(vmsks_shape.end() - 1) = llvm::divideCeil(vmsks_shape.back(), 2);
    xla::Array<Value> out_vmsks(vmsks_shape, nullptr);
    SmallVector<int64_t> val_idx;
    // All-false mask used to pad the high half when the last input vreg has
    // no partner.
    Value default_val =
        getFullLikeVector(builder, cast<TypedValue<VectorType>>(*vals.begin()),
                          IntegerAttr::get(builder.getI1Type(), 0));
    // The packed vmask type is identical for every output vreg; compute it
    // once outside the loop (in_bitwidth / 2 == out_bitwidth here).
    const VectorType mask_ty = getNativeVregOrVmaskType(
        builder.getI1Type(), in_bitwidth / 2, ctx.target_shape);
    out_vmsks.Each([&](absl::Span<const int64_t> idx, Value *v) {
      val_idx.assign(idx.begin(), idx.end());
      // TODO(jevinjiang): can be simplified when offset is replicated.
      // Output vreg i packs input vregs 2*i (low) and 2*i + 1 (high).
      *(val_idx.end() - 1) *= 2;
      Value low_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1)
                           ? vals(val_idx)
                           : default_val;
      *(val_idx.end() - 1) += 1;
      Value high_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1)
                            ? vals(val_idx)
                            : default_val;
      *v = builder.create<PackMaskOp>(mask_ty, low_part, high_part);
    });
    const RollVectorsOp rolled_op =
        assemble(builder, vty, out_layout, out_vmsks, ctx.target_shape,
                 /*use_implicit_shape=*/true);
    op.replaceAllUsesWith(rolled_op);
    op.erase();
    return success();
  }
  return op.emitOpError("Not implemented: unsupported layout change");
}

// TODO(b/347016737): Deprecate tpu.rotate and only use tpu.dynamic_rotate. So
// we do not need template for the op type and to explicitly force amount
// argument to dynamic.
Expand Down Expand Up @@ -4644,9 +4712,9 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
return success();
}

LogicalResult prng_random_bits_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
TPU_ASSERT_EQ_OP(layouts_in.size(), 0);
TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
TPU_ASSERT_OP(layouts_out.front().has_value());
Expand Down Expand Up @@ -4711,7 +4779,8 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},
{tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule},
{tpu::PRNGRandomBitsOp::getOperationName(), tpu_prng_random_bits_rule},
{tpu::RelayoutOp::getOperationName(), tpu_relayout_rule},
{tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule},
{vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
{vector::ExtractOp::getOperationName(), vector_extract_rule},
Expand Down
71 changes: 46 additions & 25 deletions jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,40 @@ namespace {

FailureOr<TypedValue<VectorType>> relayout(
OpBuilder &builder, TypedValue<VectorType> v, VectorLayout src,
VectorLayout dst, const std::array<int64_t, 2> target_shape) {
VectorLayout dst, int hardware_generation,
const std::array<int64_t, 2> target_shape) {
// change bitwidth
if (v.getType().getElementType() == builder.getI1Type() &&
// TODO(jevinjiang): for other relayout changes (tiling, offsets, implicit
// dim), we currently rely on apply-vector-layout pass to do the relayout.
src.bitwidth() != dst.bitwidth()) {
auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
auto dst_bitwidth_layout = VectorLayout(
dst.bitwidth(),
{
src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0]
: LayoutOffset(),
src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1]
: LayoutOffset(),
},
src.tiling(), src.implicit_dim());
if (!dst_bitwidth_layout.isValid(target_shape)) {
return emitError(v.getLoc(),
"Not implemented: failed to infer valid layout during "
"relayout, got ")
<< dst_bitwidth_layout;
}
// We might be able to pack mask directly.
// TODO(jevinjiang): Add support for 16bit -> 8bit mask packing.
if (src.bitwidth() == 32 && dst.bitwidth() == 16 &&
// TODO(jevinjiang): support mask packing for non-native source tiling.
src.tiling()[0] == src.packing() * target_shape[0] &&
src.tiling()[1] == target_shape[1]) {
auto relayout_op =
builder.create<tpu::RelayoutOp>(v.getLoc(), v.getType(), v);
setLayout(relayout_op, src, dst_bitwidth_layout);
return cast<TypedValue<VectorType>>(relayout_op.getResult());
}
CHECK(llvm::isPowerOf2_32(src.bitwidth()));
CHECK(llvm::isPowerOf2_32(dst.bitwidth()));
auto make_vty = [&](int bitwidth) {
Expand All @@ -56,25 +84,9 @@ FailureOr<TypedValue<VectorType>> relayout(
};
auto src_int_vty = make_vty(src.bitwidth());
auto dst_int_vty = make_vty(dst.bitwidth());
auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
// TODO(jevinjiang): Since dst_bitwidth_layout will be firstly used in the
// extSI or truncI below, we can reuse the inferExt and inferTrunc from
// infer-vector-layout pass.
auto dst_bitwidth_layout = VectorLayout(
dst.bitwidth(),
{
src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0]
: LayoutOffset(),
src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1]
: LayoutOffset(),
},
src.tiling(), src.implicit_dim());
if (!dst_bitwidth_layout.isValid(target_shape)) {
return emitError(v.getLoc(),
"Not implemented: failed to infer valid layout during "
"relayout, got ")
<< dst_bitwidth_layout;
}
auto ext_op = builder.create<arith::ExtUIOp>(v.getLoc(), src_int_vty, v);
setLayout(ext_op, src, src);

Expand All @@ -98,7 +110,7 @@ FailureOr<TypedValue<VectorType>> relayout(

// TODO(jevinjiang): make relayout to an op so we don't need decide when to
// relayout in apply-vector-layout pass.
LogicalResult insertRelayout(Operation &op,
LogicalResult insertRelayout(Operation &op, int hardware_generation,
const std::array<int64_t, 2> target_shape) {
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layouts,
getInLayouts(op, target_shape));
Expand Down Expand Up @@ -136,24 +148,32 @@ LogicalResult insertRelayout(Operation &op,
continue;
}
OpBuilder builder(&op);
FAILUREOR_ASSIGN_OR_RETURN(Value new_v,
relayout(builder, vector_operand, /*src=*/*lo,
/*dst=*/*li, target_shape));
FAILUREOR_ASSIGN_OR_RETURN(
Value new_v, relayout(builder, vector_operand, /*src=*/*lo,
/*dst=*/*li, hardware_generation, target_shape));
op.setOperand(idx, new_v);
}
return success();
}

struct RelayoutInsertionPass
: public impl::RelayoutInsertionPassBase<RelayoutInsertionPass> {
RelayoutInsertionPass(std::array<int64_t, 2> target_shape) {
RelayoutInsertionPass(int generation, std::array<int64_t, 2> target_shape) {
this->hardware_generation = generation;
this->sublane_count = target_shape[0];
this->lane_count = target_shape[1];
}
void runOnOperation() override {
// Fail if hardware_generation has not been set from the default value.
if (hardware_generation < 0) {
getOperation().emitError("hardware_generation must be set");
signalPassFailure();
return;
}
func::FuncOp func = getOperation();
auto result = func.walk([&](Operation *op) {
if (insertRelayout(*op, {sublane_count, lane_count}).failed()) {
if (insertRelayout(*op, hardware_generation, {sublane_count, lane_count})
.failed()) {
return WalkResult::interrupt();
}
return WalkResult::advance();
Expand All @@ -168,8 +188,9 @@ struct RelayoutInsertionPass
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
std::array<int64_t, 2> target_shape) {
return std::make_unique<RelayoutInsertionPass>(target_shape);
int hardware_generation, std::array<int64_t, 2> target_shape) {
return std::make_unique<RelayoutInsertionPass>(hardware_generation,
target_shape);
}

} // namespace mlir::tpu
13 changes: 4 additions & 9 deletions tests/pallas/tpu_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,22 +332,17 @@ def kernel(x, out):
dtype=[jnp.float32, jnp.bfloat16],
)
def test_i1_relayout_with_bitwidth_change(self, msk_dtype, dtype):
# TODO(jevinjiang): Remove after 12 weeks have passed.
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
self.skipTest("Requires libtpu built after 2024-12-19")
if not jtu.if_cloud_tpu_at_least(2025, 1, 25):
self.skipTest("Requires libtpu built after 2025-01-25")
shape = (129, 129)
msk_bitwidth = pallas_utils.dtype_bitwidth(msk_dtype)
bitwidth = pallas_utils.dtype_bitwidth(dtype)
if (
(jtu.get_tpu_version() > 5 and msk_bitwidth < 8)
or (jtu.get_tpu_version() == 5 and msk_bitwidth not in (8, 32))
or (jtu.get_tpu_version() < 5 and msk_bitwidth < 32)
):
if jtu.get_tpu_version() < 5 and msk_bitwidth < 32:
self.skipTest(
"Not implemented: cast vector to mask with bitwidth =="
f" {msk_bitwidth}"
)
if jtu.get_tpu_version() <= 5 and bitwidth < 32:
if jtu.get_tpu_version() < 5 and bitwidth < 32:
self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}")

@functools.partial(
Expand Down
Loading