Skip to content

Commit 6502f55

Browse files
authored
[CIR][ThroughMLIR] Templatize unary math op lowerings. (#1557)
A lot of the unary math op lowerings follow the same template -- we can templatize this to remove redundant code and make things a little more neater. (Similar to what we do [here](https://github.com/llvm/clangir/blob/e4b8a48fb4d9a72a85e38f5439bcfb0673b4bea2/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp#L502)) I've checked all existing LIT tests via `ninja clang-check-cir` and they seem to be passing fine.
1 parent 84b4603 commit 6502f55

File tree

1 file changed

+37
-193
lines changed

1 file changed

+37
-193
lines changed

Diff for: clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+37-193
Original file line numberDiff line numberDiff line change
@@ -275,172 +275,52 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
275275
}
276276
};
277277

278-
class CIRACosOpLowering : public mlir::OpConversionPattern<cir::ACosOp> {
279-
public:
280-
using OpConversionPattern<cir::ACosOp>::OpConversionPattern;
281-
282-
mlir::LogicalResult
283-
matchAndRewrite(cir::ACosOp op, OpAdaptor adaptor,
284-
mlir::ConversionPatternRewriter &rewriter) const override {
285-
rewriter.replaceOpWithNewOp<mlir::math::AcosOp>(op, adaptor.getSrc());
286-
return mlir::LogicalResult::success();
287-
}
288-
};
289-
290-
class CIRATanOpLowering : public mlir::OpConversionPattern<cir::ATanOp> {
291-
public:
292-
using OpConversionPattern<cir::ATanOp>::OpConversionPattern;
293-
294-
mlir::LogicalResult
295-
matchAndRewrite(cir::ATanOp op, OpAdaptor adaptor,
296-
mlir::ConversionPatternRewriter &rewriter) const override {
297-
rewriter.replaceOpWithNewOp<mlir::math::AtanOp>(op, adaptor.getSrc());
298-
return mlir::LogicalResult::success();
299-
}
300-
};
301-
302-
class CIRCosOpLowering : public mlir::OpConversionPattern<cir::CosOp> {
303-
public:
304-
using OpConversionPattern<cir::CosOp>::OpConversionPattern;
305-
306-
mlir::LogicalResult
307-
matchAndRewrite(cir::CosOp op, OpAdaptor adaptor,
308-
mlir::ConversionPatternRewriter &rewriter) const override {
309-
rewriter.replaceOpWithNewOp<mlir::math::CosOp>(op, adaptor.getSrc());
310-
return mlir::LogicalResult::success();
311-
}
312-
};
313-
314-
class CIRTanOpLowering : public mlir::OpConversionPattern<cir::TanOp> {
315-
public:
316-
using OpConversionPattern<cir::TanOp>::OpConversionPattern;
317-
318-
mlir::LogicalResult
319-
matchAndRewrite(cir::TanOp op, OpAdaptor adaptor,
320-
mlir::ConversionPatternRewriter &rewriter) const override {
321-
rewriter.replaceOpWithNewOp<mlir::math::TanOp>(op, adaptor.getSrc());
322-
return mlir::LogicalResult::success();
323-
}
324-
};
325-
326-
class CIRSqrtOpLowering : public mlir::OpConversionPattern<cir::SqrtOp> {
327-
public:
328-
using mlir::OpConversionPattern<cir::SqrtOp>::OpConversionPattern;
278+
/// Converts CIR unary math ops (e.g., cir::SinOp) to their MLIR equivalents
279+
/// (e.g., math::SinOp) using a generic template to avoid redundant boilerplate
280+
/// matchAndRewrite definitions.
329281

330-
mlir::LogicalResult
331-
matchAndRewrite(cir::SqrtOp op, OpAdaptor adaptor,
332-
mlir::ConversionPatternRewriter &rewriter) const override {
333-
rewriter.replaceOpWithNewOp<mlir::math::SqrtOp>(op, adaptor.getSrc());
334-
return mlir::LogicalResult::success();
335-
}
336-
};
337-
338-
class CIRFAbsOpLowering : public mlir::OpConversionPattern<cir::FAbsOp> {
339-
public:
340-
using mlir::OpConversionPattern<cir::FAbsOp>::OpConversionPattern;
341-
342-
mlir::LogicalResult
343-
matchAndRewrite(cir::FAbsOp op, OpAdaptor adaptor,
344-
mlir::ConversionPatternRewriter &rewriter) const override {
345-
rewriter.replaceOpWithNewOp<mlir::math::AbsFOp>(op, adaptor.getSrc());
346-
return mlir::LogicalResult::success();
347-
}
348-
};
349-
class CIRAbsOpLowering : public mlir::OpConversionPattern<cir::AbsOp> {
350-
public:
351-
using mlir::OpConversionPattern<cir::AbsOp>::OpConversionPattern;
352-
353-
mlir::LogicalResult
354-
matchAndRewrite(cir::AbsOp op, OpAdaptor adaptor,
355-
mlir::ConversionPatternRewriter &rewriter) const override {
356-
rewriter.replaceOpWithNewOp<mlir::math::AbsIOp>(op, adaptor.getSrc());
357-
return mlir::LogicalResult::success();
358-
}
359-
};
360-
361-
class CIRFloorOpLowering : public mlir::OpConversionPattern<cir::FloorOp> {
362-
public:
363-
using mlir::OpConversionPattern<cir::FloorOp>::OpConversionPattern;
364-
365-
mlir::LogicalResult
366-
matchAndRewrite(cir::FloorOp op, OpAdaptor adaptor,
367-
mlir::ConversionPatternRewriter &rewriter) const override {
368-
rewriter.replaceOpWithNewOp<mlir::math::FloorOp>(op, adaptor.getSrc());
369-
return mlir::LogicalResult::success();
370-
}
371-
};
372-
373-
class CIRCeilOpLowering : public mlir::OpConversionPattern<cir::CeilOp> {
374-
public:
375-
using mlir::OpConversionPattern<cir::CeilOp>::OpConversionPattern;
376-
377-
mlir::LogicalResult
378-
matchAndRewrite(cir::CeilOp op, OpAdaptor adaptor,
379-
mlir::ConversionPatternRewriter &rewriter) const override {
380-
rewriter.replaceOpWithNewOp<mlir::math::CeilOp>(op, adaptor.getSrc());
381-
return mlir::LogicalResult::success();
382-
}
383-
};
384-
385-
class CIRLog10OpLowering : public mlir::OpConversionPattern<cir::Log10Op> {
386-
public:
387-
using mlir::OpConversionPattern<cir::Log10Op>::OpConversionPattern;
388-
389-
mlir::LogicalResult
390-
matchAndRewrite(cir::Log10Op op, OpAdaptor adaptor,
391-
mlir::ConversionPatternRewriter &rewriter) const override {
392-
rewriter.replaceOpWithNewOp<mlir::math::Log10Op>(op, adaptor.getSrc());
393-
return mlir::LogicalResult::success();
394-
}
395-
};
396-
397-
class CIRLogOpLowering : public mlir::OpConversionPattern<cir::LogOp> {
398-
public:
399-
using mlir::OpConversionPattern<cir::LogOp>::OpConversionPattern;
400-
401-
mlir::LogicalResult
402-
matchAndRewrite(cir::LogOp op, OpAdaptor adaptor,
403-
mlir::ConversionPatternRewriter &rewriter) const override {
404-
rewriter.replaceOpWithNewOp<mlir::math::LogOp>(op, adaptor.getSrc());
405-
return mlir::LogicalResult::success();
406-
}
407-
};
408-
409-
class CIRLog2OpLowering : public mlir::OpConversionPattern<cir::Log2Op> {
410-
public:
411-
using mlir::OpConversionPattern<cir::Log2Op>::OpConversionPattern;
412-
413-
mlir::LogicalResult
414-
matchAndRewrite(cir::Log2Op op, OpAdaptor adaptor,
415-
mlir::ConversionPatternRewriter &rewriter) const override {
416-
rewriter.replaceOpWithNewOp<mlir::math::Log2Op>(op, adaptor.getSrc());
417-
return mlir::LogicalResult::success();
418-
}
419-
};
420-
421-
class CIRRoundOpLowering : public mlir::OpConversionPattern<cir::RoundOp> {
282+
template <typename CIROp, typename MLIROp>
283+
class CIRUnaryMathOpLowering : public mlir::OpConversionPattern<CIROp> {
422284
public:
423-
using mlir::OpConversionPattern<cir::RoundOp>::OpConversionPattern;
285+
using mlir::OpConversionPattern<CIROp>::OpConversionPattern;
424286

425287
mlir::LogicalResult
426-
matchAndRewrite(cir::RoundOp op, OpAdaptor adaptor,
288+
matchAndRewrite(CIROp op,
289+
typename mlir::OpConversionPattern<CIROp>::OpAdaptor adaptor,
427290
mlir::ConversionPatternRewriter &rewriter) const override {
428-
rewriter.replaceOpWithNewOp<mlir::math::RoundOp>(op, adaptor.getSrc());
291+
rewriter.replaceOpWithNewOp<MLIROp>(op, adaptor.getSrc());
429292
return mlir::LogicalResult::success();
430293
}
431294
};
432295

433-
class CIRExpOpLowering : public mlir::OpConversionPattern<cir::ExpOp> {
434-
public:
435-
using mlir::OpConversionPattern<cir::ExpOp>::OpConversionPattern;
436-
437-
mlir::LogicalResult
438-
matchAndRewrite(cir::ExpOp op, OpAdaptor adaptor,
439-
mlir::ConversionPatternRewriter &rewriter) const override {
440-
rewriter.replaceOpWithNewOp<mlir::math::ExpOp>(op, adaptor.getSrc());
441-
return mlir::LogicalResult::success();
442-
}
443-
};
296+
using CIRASinOpLowering =
297+
CIRUnaryMathOpLowering<cir::ASinOp, mlir::math::AsinOp>;
298+
using CIRSinOpLowering = CIRUnaryMathOpLowering<cir::SinOp, mlir::math::SinOp>;
299+
using CIRExp2OpLowering =
300+
CIRUnaryMathOpLowering<cir::Exp2Op, mlir::math::Exp2Op>;
301+
using CIRExpOpLowering = CIRUnaryMathOpLowering<cir::ExpOp, mlir::math::ExpOp>;
302+
using CIRRoundOpLowering =
303+
CIRUnaryMathOpLowering<cir::RoundOp, mlir::math::RoundOp>;
304+
using CIRLog2OpLowering =
305+
CIRUnaryMathOpLowering<cir::Log2Op, mlir::math::Log2Op>;
306+
using CIRLogOpLowering = CIRUnaryMathOpLowering<cir::LogOp, mlir::math::LogOp>;
307+
using CIRLog10OpLowering =
308+
CIRUnaryMathOpLowering<cir::Log10Op, mlir::math::Log10Op>;
309+
using CIRCeilOpLowering =
310+
CIRUnaryMathOpLowering<cir::CeilOp, mlir::math::CeilOp>;
311+
using CIRFloorOpLowering =
312+
CIRUnaryMathOpLowering<cir::FloorOp, mlir::math::FloorOp>;
313+
using CIRAbsOpLowering = CIRUnaryMathOpLowering<cir::AbsOp, mlir::math::AbsIOp>;
314+
using CIRFAbsOpLowering =
315+
CIRUnaryMathOpLowering<cir::FAbsOp, mlir::math::AbsFOp>;
316+
using CIRSqrtOpLowering =
317+
CIRUnaryMathOpLowering<cir::SqrtOp, mlir::math::SqrtOp>;
318+
using CIRCosOpLowering = CIRUnaryMathOpLowering<cir::CosOp, mlir::math::CosOp>;
319+
using CIRATanOpLowering =
320+
CIRUnaryMathOpLowering<cir::ATanOp, mlir::math::AtanOp>;
321+
using CIRACosOpLowering =
322+
CIRUnaryMathOpLowering<cir::ACosOp, mlir::math::AcosOp>;
323+
using CIRTanOpLowering = CIRUnaryMathOpLowering<cir::TanOp, mlir::math::TanOp>;
444324

445325
class CIRShiftOpLowering : public mlir::OpConversionPattern<cir::ShiftOp> {
446326
public:
@@ -475,42 +355,6 @@ class CIRShiftOpLowering : public mlir::OpConversionPattern<cir::ShiftOp> {
475355
}
476356
};
477357

478-
class CIRExp2OpLowering : public mlir::OpConversionPattern<cir::Exp2Op> {
479-
public:
480-
using mlir::OpConversionPattern<cir::Exp2Op>::OpConversionPattern;
481-
482-
mlir::LogicalResult
483-
matchAndRewrite(cir::Exp2Op op, OpAdaptor adaptor,
484-
mlir::ConversionPatternRewriter &rewriter) const override {
485-
rewriter.replaceOpWithNewOp<mlir::math::Exp2Op>(op, adaptor.getSrc());
486-
return mlir::LogicalResult::success();
487-
}
488-
};
489-
490-
class CIRSinOpLowering : public mlir::OpConversionPattern<cir::SinOp> {
491-
public:
492-
using mlir::OpConversionPattern<cir::SinOp>::OpConversionPattern;
493-
494-
mlir::LogicalResult
495-
matchAndRewrite(cir::SinOp op, OpAdaptor adaptor,
496-
mlir::ConversionPatternRewriter &rewriter) const override {
497-
rewriter.replaceOpWithNewOp<mlir::math::SinOp>(op, adaptor.getSrc());
498-
return mlir::LogicalResult::success();
499-
}
500-
};
501-
502-
class CIRASinOpLowering : public mlir::OpConversionPattern<cir::ASinOp> {
503-
public:
504-
using mlir::OpConversionPattern<cir::ASinOp>::OpConversionPattern;
505-
506-
mlir::LogicalResult
507-
matchAndRewrite(cir::ASinOp op, OpAdaptor adaptor,
508-
mlir::ConversionPatternRewriter &rewriter) const override {
509-
rewriter.replaceOpWithNewOp<mlir::math::AsinOp>(op, adaptor.getSrc());
510-
return mlir::LogicalResult::success();
511-
}
512-
};
513-
514358
template <typename CIROp, typename MLIROp>
515359
class CIRCountZerosBitOpLowering : public mlir::OpConversionPattern<CIROp> {
516360
public:

0 commit comments

Comments
 (0)