diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py
index 59a63c58ad10..75e502ef55e6 100644
--- a/jax/_src/tpu_custom_call.py
+++ b/jax/_src/tpu_custom_call.py
@@ -407,6 +407,7 @@ def get_dump_file_prefix() -> str:
         (
             "func.func(tpu-relayout-insertion{"
             f" sublane-count={sl_cnt} lane-count={l_cnt}"
+            f" hardware-generation={hardware_generation}"
             "})"
         ),
     ]
diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 845b690912e8..eca8d8336f98 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -389,6 +389,21 @@ def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]>
   let hasVerifier = 1;
 }
 
+def TPU_RelayoutOp : TPU_Op<"relayout", [SameOperandsAndResultType]> {
+  let arguments = (ins AnyType:$input);
+  let results = (outs AnyType:$output);
+  let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
+}
+
+def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> {
+  let arguments = (ins
+    VectorOfNonZeroRankOf<[I1]>: $low,
+    VectorOfNonZeroRankOf<[I1]>: $high
+  );
+  let results = (outs VectorOfNonZeroRankOf<[I1]>:$output);
+  let assemblyFormat = [{ $low `,` $high `,` attr-dict `:` type($low) `,` type($high) `->` type($output) }];
+}
+
 def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
   let arguments = (ins
     AnyVectorOfNonZeroRank:$source,
@@ -891,6 +906,9 @@ def RelayoutInsertionPass : Pass<"tpu-relayout-insertion", "::mlir::func::FuncOp"> {
   ];
   let constructor = "::mlir::tpu::createRelayoutInsertionPass()";
   let options = [
+    // If hardware_generation is not set, the default value of -1 will cause
+    // runOnOperation to fail.
+    Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
     Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
     Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
   ];
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
index 1e0418f518bb..0800a9e75087 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -84,6 +84,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
     const TpuTilingFlags &tpu_tiling_flags = {});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
+    int hardware_generation = -1,
     std::array<int64_t, 2> target_shape = {8, 128});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 4539d6035709..de8ac2a3304a 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -2172,6 +2172,74 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op,
   return success();
 }
 
+LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op,
+                                const ArrayRef<Layout> layouts_in,
+                                const ArrayRef<Layout> layouts_out) {
+  TPU_ASSERT_EQ_OP(op.getNumOperands(), 1);
+  TPU_ASSERT_EQ_OP(op.getNumResults(), 1);
+  TPU_ASSERT_EQ_OP(layouts_in.size(), 1);
+  TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
+  TPU_ASSERT_OP(layouts_in[0].has_value());
+  TPU_ASSERT_OP(layouts_out[0].has_value());
+  const auto& in_layout = *layouts_in[0];
+  const auto& out_layout = *layouts_out[0];
+  auto relayout_op = cast<tpu::RelayoutOp>(op);
+  auto in_bitwidth = in_layout.bitwidth();
+  auto out_bitwidth = out_layout.bitwidth();
+  auto vty = cast<VectorType>(relayout_op.getType());
+  ImplicitLocOpBuilder builder(op.getLoc(), &op);
+  if (in_layout == out_layout) {
+    relayout_op.replaceAllUsesWith(relayout_op.getInput());
+    relayout_op.erase();
+    return success();
+  }
+  FAILUREOR_ASSIGN_OR_RETURN(
+      xla::Array<Value> vals,
+      disassemble(builder, in_layout,
+                  cast<TypedValue<VectorType>>(relayout_op.getInput()),
+                  ctx.target_shape,
+                  /*use_implicit_shape=*/true));
+  // Packing vector masks from 32-bit to 16-bit.
+  if (vty.getElementType() == builder.getI1Type() && in_bitwidth == 32 &&
+      out_bitwidth == 16 &&
+      in_layout.tiling()[0] == in_layout.packing() * ctx.target_shape[0] &&
+      in_layout.tiling()[1] == ctx.target_shape[1] &&
+      in_layout.tiling() == out_layout.tiling() &&
+      in_layout.offsets() == out_layout.offsets() &&
+      in_layout.implicit_dim() == out_layout.implicit_dim()) {
+    std::vector<int64_t> vmsks_shape(vals.dimensions().begin(),
+                                     vals.dimensions().end());
+    *(vmsks_shape.end() - 1) = llvm::divideCeil(vmsks_shape.back(), 2);
+    xla::Array<Value> out_vmsks(vmsks_shape, nullptr);
+    SmallVector<int64_t> val_idx;
+    Value default_val =
+        getFullLikeVector(builder, cast<TypedValue<VectorType>>(*vals.begin()),
+                          IntegerAttr::get(builder.getI1Type(), 0));
+    out_vmsks.Each([&](absl::Span<const int64_t> idx, Value *v) {
+      val_idx.assign(idx.begin(), idx.end());
+      // TODO(jevinjiang): can be simplified when offset is replicated.
+      *(val_idx.end() - 1) *= 2;
+      Value low_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1)
+                           ? vals(val_idx)
+                           : default_val;
+      *(val_idx.end() - 1) += 1;
+      Value high_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1)
+                            ? vals(val_idx)
+                            : default_val;
+      const VectorType mask_ty = getNativeVregOrVmaskType(
+          builder.getI1Type(), in_bitwidth / 2, ctx.target_shape);
+      *v = builder.create<PackMaskOp>(mask_ty, low_part, high_part);
+    });
+    const RollVectorsOp rolled_op =
+        assemble(builder, vty, out_layout, out_vmsks, ctx.target_shape,
+                 /*use_implicit_shape=*/true);
+    op.replaceAllUsesWith(rolled_op);
+    op.erase();
+    return success();
+  }
+  return op.emitOpError("Not implemented: unsupported layout change");
+}
+
 // TODO(b/347016737): Deprecate tpu.rotate and only use tpu.dynamic_rotate. So
 // we do not need template for the op type and to explicitly force amount
 // argument to dynamic.
@@ -4644,9 +4712,9 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
   return success();
 }
 
-LogicalResult prng_random_bits_rule(RewriteContext &ctx, Operation &op,
-                                    const ArrayRef<Layout> layouts_in,
-                                    const ArrayRef<Layout> layouts_out) {
+LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op,
+                                        const ArrayRef<Layout> layouts_in,
+                                        const ArrayRef<Layout> layouts_out) {
   TPU_ASSERT_EQ_OP(layouts_in.size(), 0);
   TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
   TPU_ASSERT_OP(layouts_out.front().has_value());
@@ -4711,7 +4779,8 @@ const llvm::StringMap<rule_type> &rules() {
       {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
       {tpu::TraceOp::getOperationName(), tpu_trace_rule},
       {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},
-      {tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule},
+      {tpu::PRNGRandomBitsOp::getOperationName(), tpu_prng_random_bits_rule},
+      {tpu::RelayoutOp::getOperationName(), tpu_relayout_rule},
       {tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule},
       {vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
       {vector::ExtractOp::getOperationName(), vector_extract_rule},
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc
index a6099da88949..b88504e35068 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc
@@ -31,12 +31,40 @@ namespace {
 
 FailureOr<TypedValue<VectorType>> relayout(
     OpBuilder &builder, TypedValue<VectorType> v, VectorLayout src,
-    VectorLayout dst, const std::array<int64_t, 2> target_shape) {
+    VectorLayout dst, int hardware_generation,
+    const std::array<int64_t, 2> target_shape) {
   // change bitwidth
   if (v.getType().getElementType() == builder.getI1Type() &&
       // TODO(jevinjiang): for other relayout changes (tiling, offsets, implicit
       // dim), we currently rely on apply-vector-layout pass to do the relayout.
       src.bitwidth() != dst.bitwidth()) {
+    auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
+    auto dst_bitwidth_layout = VectorLayout(
+        dst.bitwidth(),
+        {
+            src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0]
+                                         : LayoutOffset(),
+            src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1]
+                                         : LayoutOffset(),
+        },
+        src.tiling(), src.implicit_dim());
+    if (!dst_bitwidth_layout.isValid(target_shape)) {
+      return emitError(v.getLoc(),
+                       "Not implemented: failed to infer valid layout during "
+                       "relayout, got ")
+             << dst_bitwidth_layout;
+    }
+    // We might be able to pack mask directly.
+    // TODO(jevinjiang): Add support for 16bit -> 8bit mask packing.
+    if (src.bitwidth() == 32 && dst.bitwidth() == 16 &&
+        // TODO(jevinjiang): support mask packing for non-native source tiling.
+        src.tiling()[0] == src.packing() * target_shape[0] &&
+        src.tiling()[1] == target_shape[1]) {
+      auto relayout_op =
+          builder.create<tpu::RelayoutOp>(v.getLoc(), v.getType(), v);
+      setLayout(relayout_op, src, dst_bitwidth_layout);
+      return cast<TypedValue<VectorType>>(relayout_op.getResult());
+    }
     CHECK(llvm::isPowerOf2_32(src.bitwidth()));
     CHECK(llvm::isPowerOf2_32(dst.bitwidth()));
     auto make_vty = [&](int bitwidth) {
@@ -56,25 +84,9 @@ FailureOr<TypedValue<VectorType>> relayout(
     };
     auto src_int_vty = make_vty(src.bitwidth());
     auto dst_int_vty = make_vty(dst.bitwidth());
-    auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
     // TODO(jevinjiang): Since dst_bitwidth_layout will be firstly used in the
     // extSI or truncI below, we can reuse the inferExt and inferTrunc from
     // infer-vector-layout pass.
-    auto dst_bitwidth_layout = VectorLayout(
-        dst.bitwidth(),
-        {
-            src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0]
-                                         : LayoutOffset(),
-            src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1]
-                                         : LayoutOffset(),
-        },
-        src.tiling(), src.implicit_dim());
-    if (!dst_bitwidth_layout.isValid(target_shape)) {
-      return emitError(v.getLoc(),
-                       "Not implemented: failed to infer valid layout during "
-                       "relayout, got ")
-             << dst_bitwidth_layout;
-    }
     auto ext_op = builder.create<arith::ExtSIOp>(v.getLoc(), src_int_vty, v);
     setLayout(ext_op, src, src);
 
@@ -98,7 +110,7 @@ FailureOr<TypedValue<VectorType>> relayout(
 
 // TODO(jevinjiang): make relayout to an op so we don't need decide when to
 // relayout in apply-vector-layout pass.
-LogicalResult insertRelayout(Operation &op,
+LogicalResult insertRelayout(Operation &op, int hardware_generation,
                              const std::array<int64_t, 2> target_shape) {
   FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layouts,
                              getInLayouts(op, target_shape));
@@ -136,9 +148,9 @@ LogicalResult insertRelayout(Operation &op,
       continue;
     }
     OpBuilder builder(&op);
-    FAILUREOR_ASSIGN_OR_RETURN(Value new_v,
-                               relayout(builder, vector_operand, /*src=*/*lo,
-                                        /*dst=*/*li, target_shape));
+    FAILUREOR_ASSIGN_OR_RETURN(
+        Value new_v, relayout(builder, vector_operand, /*src=*/*lo,
+                              /*dst=*/*li, hardware_generation, target_shape));
     op.setOperand(idx, new_v);
   }
   return success();
@@ -146,14 +158,22 @@ LogicalResult insertRelayout(Operation &op,
 
 struct RelayoutInsertionPass
     : public impl::RelayoutInsertionPassBase<RelayoutInsertionPass> {
-  RelayoutInsertionPass(std::array<int64_t, 2> target_shape) {
+  RelayoutInsertionPass(int generation, std::array<int64_t, 2> target_shape) {
+    this->hardware_generation = generation;
     this->sublane_count = target_shape[0];
    this->lane_count = target_shape[1];
   }
   void runOnOperation() override {
+    // Fail if hardware_generation has not been overridden from its default
+    // value.
+    if (hardware_generation < 0) {
+      getOperation().emitError("hardware_generation must be set");
+      signalPassFailure();
+      return;
+    }
     func::FuncOp func = getOperation();
     auto result = func.walk([&](Operation *op) {
-      if (insertRelayout(*op, {sublane_count, lane_count}).failed()) {
+      if (insertRelayout(*op, hardware_generation, {sublane_count, lane_count})
+              .failed()) {
         return WalkResult::interrupt();
       }
       return WalkResult::advance();
@@ -168,8 +188,9 @@ struct RelayoutInsertionPass
 }  // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
-    std::array<int64_t, 2> target_shape) {
-  return std::make_unique<RelayoutInsertionPass>(target_shape);
+    int hardware_generation, std::array<int64_t, 2> target_shape) {
+  return std::make_unique<RelayoutInsertionPass>(hardware_generation,
+                                                 target_shape);
 }
 
 }  // namespace mlir::tpu
\ No newline at end of file
diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py
index 29fd741814ba..a1102b461978 100644
--- a/tests/pallas/tpu_ops_test.py
+++ b/tests/pallas/tpu_ops_test.py
@@ -332,22 +332,17 @@ def kernel(x, out):
       dtype=[jnp.float32, jnp.bfloat16],
   )
   def test_i1_relayout_with_bitwidth_change(self, msk_dtype, dtype):
-    # TODO(jevinjiang): Remove after 12 weeks have passed.
-    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
-      self.skipTest("Requires libtpu built after 2024-12-19")
+    if not jtu.if_cloud_tpu_at_least(2025, 1, 25):
+      self.skipTest("Requires libtpu built after 2025-01-25")
     shape = (129, 129)
     msk_bitwidth = pallas_utils.dtype_bitwidth(msk_dtype)
     bitwidth = pallas_utils.dtype_bitwidth(dtype)
-    if (
-        (jtu.get_tpu_version() > 5 and msk_bitwidth < 8)
-        or (jtu.get_tpu_version() == 5 and msk_bitwidth not in (8, 32))
-        or (jtu.get_tpu_version() < 5 and msk_bitwidth < 32)
-    ):
+    if jtu.get_tpu_version() < 5 and msk_bitwidth < 32:
       self.skipTest(
           "Not implemented: cast vector to mask with bitwidth =="
           f" {msk_bitwidth}"
       )
-    if jtu.get_tpu_version() <= 5 and bitwidth < 32:
+    if jtu.get_tpu_version() < 5 and bitwidth < 32:
       self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}")
 
     @functools.partial(
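
For context, a minimal, illustrative Pallas sketch (not part of this patch) of the pattern the updated test exercises: a mask derived from 32-bit operands is used to select bfloat16 values, which is the 32-bit to 16-bit mask relayout that the new tpu.relayout / tpu.pack_vmsk path handles. It assumes a TPU backend; the kernel and helper names (select_kernel, run) are hypothetical.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def select_kernel(x_ref, y_ref, o_ref):
  # Comparing int32 inputs yields an i1 mask laid out at 32-bit granularity;
  # selecting bfloat16 values with it requires repacking the mask to 16-bit
  # granularity (the relayout under test).
  mask = x_ref[...] > 0
  o_ref[...] = jnp.where(mask, y_ref[...], -y_ref[...])


@jax.jit
def run(x, y):
  return pl.pallas_call(
      select_kernel,
      out_shape=jax.ShapeDtypeStruct(y.shape, y.dtype),
  )(x, y)


x = jnp.broadcast_to(jnp.arange(-64, 65, dtype=jnp.int32)[:, None], (129, 129))
y = jnp.full((129, 129), 2.0, dtype=jnp.bfloat16)
out = run(x, y)
print(out[0, :2], out[-1, :2])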