Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dawn/src/dawn/Compiler/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "dawn/Compiler/Driver.h"
#include "dawn/CodeGen/Driver.h"
#include "dawn/CodeGen/TranslationUnit.h"
#include "dawn/Optimizer/PassSimplifyStatements.h"
#include "dawn/SIR/SIR.h"
#include "dawn/Support/Iterator.h"
#include "dawn/Support/Logging.h"
Expand All @@ -31,6 +32,7 @@
#include "dawn/Optimizer/PassPrintStencilGraph.h"
#include "dawn/Optimizer/PassRemoveScalars.h"
#include "dawn/Optimizer/PassSSA.h"
#include "dawn/Optimizer/PassSimplifyStatements.h"
#include "dawn/Optimizer/PassSetBlockSize.h"
#include "dawn/Optimizer/PassSetBoundaryCondition.h"
#include "dawn/Optimizer/PassSetCaches.h"
Expand Down Expand Up @@ -81,6 +83,7 @@ run(const std::shared_ptr<SIR>& stencilIR, const std::list<PassGroup>& groups,
using MultistageSplitStrategy = PassMultiStageSplitter::MultiStageSplittingStrategy;

// required passes to have proper, parallelized IR
optimizer.pushBackPass<PassSimplifyStatements>();
optimizer.pushBackPass<PassInlining>(PassInlining::InlineStrategy::InlineProcedures);
optimizer.pushBackPass<PassFieldVersioning>();
optimizer.pushBackPass<PassMultiStageSplitter>(
Expand Down
2 changes: 2 additions & 0 deletions dawn/src/dawn/Optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ add_library(DawnOptimizer
PassSetStageName.h
PassSetSyncStage.cpp
PassSetSyncStage.h
PassSimplifyStatements.h
PassSimplifyStatements.cpp
PassStageSplitAllStatements.cpp
PassStageSplitAllStatements.h
PassSSA.cpp
Expand Down
106 changes: 106 additions & 0 deletions dawn/src/dawn/Optimizer/PassSimplifyStatements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
//===--------------------------------------------------------------------------------*- C++ -*-===//
// _
// | |
// __| | __ ___ ___ ___
// / _` |/ _` \ \ /\ / / '_ |
// | (_| | (_| |\ V V /| | | |
// \__,_|\__,_| \_/\_/ |_| |_| - Compiler Toolchain
//
//
// This file is distributed under the MIT License (MIT).
// See LICENSE.txt for details.
//
//===------------------------------------------------------------------------------------------===//

#include "PassSimplifyStatements.h"
#include "dawn/AST/ASTExpr.h"
#include "dawn/IIR/ASTFwd.h"
#include "dawn/IIR/AccessComputation.h"
#include "dawn/IIR/DoMethod.h"
#include "dawn/IIR/StencilInstantiation.h"
#include "dawn/Support/Type.h"
#include <memory>

namespace dawn {

namespace {
class IncrementDecrementReplacer : public ast::ASTVisitorPostOrder {
std::vector<std::shared_ptr<ast::Stmt>> statements_;

public:
std::shared_ptr<iir::Expr>
postVisitNode(std::shared_ptr<iir::UnaryOperator> const& unaryOp) override {
auto sourceLoc = unaryOp->getSourceLocation();
std::shared_ptr<ast::BinaryOperator> binOp;
if(unaryOp->getOp() == "++") {
binOp = std::make_shared<ast::BinaryOperator>(
unaryOp->getOperand()->clone(), "+",
std::make_shared<ast::LiteralAccessExpr>("1", BuiltinTypeID::Integer, sourceLoc),
sourceLoc);
} else if(unaryOp->getOp() == "--") {
binOp = std::make_shared<ast::BinaryOperator>(
unaryOp->getOperand()->clone(), "-",
std::make_shared<ast::LiteralAccessExpr>("1", BuiltinTypeID::Integer, sourceLoc),
sourceLoc);
Comment on lines +35 to +44
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't work in this case:

if (i == 0 && ++i == 1) { /* ... */ }

will become:

i = i + 1;
if (i == 0 && i == 1) { /* ... */ }

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interesting one! Can we resolve this (without introducing a temporary variable)?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A valid transformation output would be:

if (i == 0) {
  i = i + 1;
  if (i == 1) { /* ... */ }
}

But I don't think that such an output would be suitable for a general transformation algorithm.

Note: this isn't equivalent if SIR also has C++'s short-circuit semantics (it would also introduce more ops):

i = i + 1;
if (i - 1 == 0 && i == 1) { /* ... */ }

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is not an easy problem to solve, let's leave this open for the next syntax workshop.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. But do we then merge a pass that only works for a subset of SIR and could change semantics if SIR outside that subset is given?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BenWeber42 I wouldn't merge it for now. Let's briefly present the problem at the review meeting. A quick solution could be saying that increment and decrement ops can't be nested in expressions (i.e. only allowed in statements like i++; i--;) from the beginning (frontends/SIR). If this solution is not widely accepted, then this problem will need to be discussed more in detail, because it seems that the only way to solve it is to do the heavy transformation that you showed.

} else {
return unaryOp;
}
DAWN_ASSERT(unaryOp->getOperand()->getKind() == ast::Expr::Kind::FieldAccessExpr ||
unaryOp->getOperand()->getKind() == ast::Expr::Kind::VarAccessExpr);
auto newAssignmentExpr = std::make_shared<ast::AssignmentExpr>(unaryOp->getOperand()->clone(),
binOp, "=", sourceLoc);

statements_.push_back(iir::makeExprStmt(newAssignmentExpr, sourceLoc));

return unaryOp->getOperand();
}
const std::vector<std::shared_ptr<ast::Stmt>>& getReplacements() { return statements_; }
};
} // namespace

bool PassSimplifyStatements::run(
const std::shared_ptr<iir::StencilInstantiation>& stencilInstantiation) {
for(const auto& doMethod : iterateIIROver<iir::DoMethod>(*stencilInstantiation->getIIR())) {
for(auto stmtIt = doMethod->getAST().getStatements().begin();
stmtIt != doMethod->getAST().getStatements().end();) {
// Compound assignment
if(const auto& exprStmt = std::dynamic_pointer_cast<iir::ExprStmt>(*stmtIt)) {
auto sourceLoc = exprStmt->getSourceLocation();
if(const auto& assignmentExpr =
std::dynamic_pointer_cast<iir::AssignmentExpr>(exprStmt->getExpr())) {
if(assignmentExpr->getOp() != "=") {
auto binOp = std::make_shared<ast::BinaryOperator>(
assignmentExpr->getLeft()->clone(), assignmentExpr->getOp().substr(0, 1),
assignmentExpr->getRight(), sourceLoc);
auto newAssignmentExpr = std::make_shared<ast::AssignmentExpr>(
assignmentExpr->getLeft(), binOp, "=", sourceLoc);
exprStmt->getExpr() = newAssignmentExpr;
}
}
Comment on lines +66 to +79
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to ignore compoung assignments in nested block statements (e.g., nested ifs)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, need to fix this.

}

// Increment/decrement ops (can be nested inside expression tree)
IncrementDecrementReplacer replacer;
doMethod->getAST().substitute(stmtIt, (*stmtIt)->acceptAndReplace(replacer));
stmtIt = doMethod->getAST().insert(stmtIt, replacer.getReplacements().begin(),
replacer.getReplacements().end());
std::advance(stmtIt, replacer.getReplacements().size());
// Substitution might have left an useless statement accessing a field/variable.
if(const auto& exprStmt = std::dynamic_pointer_cast<iir::ExprStmt>(*stmtIt)) {
if(exprStmt->getExpr()->getKind() == ast::Expr::Kind::FieldAccessExpr ||
exprStmt->getExpr()->getKind() == ast::Expr::Kind::VarAccessExpr) {
stmtIt = doMethod->getAST().erase(stmtIt);
continue;
}
}

++stmtIt;
}
// Recompute the accesses metadata of all statements (new statements and changed statements
// require this)
computeAccesses(stencilInstantiation->getMetaData(), doMethod->getAST().getStatements());
}
return true;
}

} // namespace dawn
35 changes: 35 additions & 0 deletions dawn/src/dawn/Optimizer/PassSimplifyStatements.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===--------------------------------------------------------------------------------*- C++ -*-===//
// _
// | |
// __| | __ ___ ___ ___
// / _` |/ _` \ \ /\ / / '_ |
// | (_| | (_| |\ V V /| | | |
// \__,_|\__,_| \_/\_/ |_| |_| - Compiler Toolchain
//
//
// This file is distributed under the MIT License (MIT).
// See LICENSE.txt for details.
//
//===------------------------------------------------------------------------------------------===//

#pragma once

#include "Pass.h"

namespace dawn {

/// @brief PassSimplifyStatements converts "advanced" statements (compound assignments and
/// increment, decrement ops) into their extended equivalent forms, to have a syntax which is
/// simpler to anaylise.
/// @ingroup optimizer
///
/// This pass is necessary to generate legal IIR
class PassSimplifyStatements : public Pass {
public:
PassSimplifyStatements(OptimizerContext& context) : Pass(context, "PassSimplifyStatements") {}

/// @brief Pass implementation
bool run(const std::shared_ptr<iir::StencilInstantiation>& stencilInstantiation) override;
};

} // namespace dawn
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ using namespace gridtools::dawn;

namespace dawn_generated{
namespace cuda{
__global__ void __launch_bounds__(32) tridiagonal_solve_stencil_stencil49_ms105_kernel(const int isize, const int jsize, const int ksize, const int stride_111_1, const int stride_111_2, ::dawn::float_type * const a, ::dawn::float_type * const b, ::dawn::float_type * const c, ::dawn::float_type * const d) {
__global__ void __launch_bounds__(32) tridiagonal_solve_stencil_stencil49_ms108_kernel(const int isize, const int jsize, const int ksize, const int stride_111_1, const int stride_111_2, ::dawn::float_type * const a, ::dawn::float_type * const b, ::dawn::float_type * const c, ::dawn::float_type * const d) {

// Start kernel
::dawn::float_type c_kcache[2];
Expand Down Expand Up @@ -155,7 +155,7 @@ if(iblock >= 0 && iblock <= block_size_i -1 + 0 && jblock >= 0 && jblock <= bloc

// Final flush of kcaches
}
__global__ void __launch_bounds__(32) tridiagonal_solve_stencil_stencil49_ms106_kernel(const int isize, const int jsize, const int ksize, const int stride_111_1, const int stride_111_2, ::dawn::float_type * const c, ::dawn::float_type * const d) {
__global__ void __launch_bounds__(32) tridiagonal_solve_stencil_stencil49_ms109_kernel(const int isize, const int jsize, const int ksize, const int stride_111_1, const int stride_111_2, ::dawn::float_type * const c, ::dawn::float_type * const d) {

// Start kernel
::dawn::float_type d_kcache[2];
Expand Down Expand Up @@ -212,7 +212,7 @@ if(iblock >= 0 && iblock <= block_size_i -1 + 0 && jblock >= 0 && jblock <= bloc
if(iblock >= 0 && iblock <= block_size_i -1 + 0 && jblock >= 0 && jblock <= block_size_j -1 + 0) {
d_kcache[0] =d[idx111];
} if(iblock >= 0 && iblock <= block_size_i -1 + 0 && jblock >= 0 && jblock <= block_size_j -1 + 0) {
d_kcache[0] -= (__ldg(&(c[idx111])) * d_kcache[1]);
d_kcache[0] = (d_kcache[0] - (__ldg(&(c[idx111])) * d_kcache[1]));
}
// Flush of kcaches

Expand Down Expand Up @@ -283,7 +283,7 @@ class tridiagonal_solve_stencil {
const unsigned int nby = (ny + 1 - 1) / 1;
const unsigned int nbz = 1;
dim3 blocks(nbx, nby, nbz);
tridiagonal_solve_stencil_stencil49_ms105_kernel<<<blocks, threads>>>(nx,ny,nz,a_ds.strides()[1],a_ds.strides()[2],(a.data()+a_ds.get_storage_info_ptr()->index(a.begin<0>(), a.begin<1>(),0 )),(b.data()+b_ds.get_storage_info_ptr()->index(b.begin<0>(), b.begin<1>(),0 )),(c.data()+c_ds.get_storage_info_ptr()->index(c.begin<0>(), c.begin<1>(),0 )),(d.data()+d_ds.get_storage_info_ptr()->index(d.begin<0>(), d.begin<1>(),0 )));
tridiagonal_solve_stencil_stencil49_ms108_kernel<<<blocks, threads>>>(nx,ny,nz,a_ds.strides()[1],a_ds.strides()[2],(a.data()+a_ds.get_storage_info_ptr()->index(a.begin<0>(), a.begin<1>(),0 )),(b.data()+b_ds.get_storage_info_ptr()->index(b.begin<0>(), b.begin<1>(),0 )),(c.data()+c_ds.get_storage_info_ptr()->index(c.begin<0>(), c.begin<1>(),0 )),(d.data()+d_ds.get_storage_info_ptr()->index(d.begin<0>(), d.begin<1>(),0 )));
};
{;
gridtools::data_view<storage_ijk_t> c= gridtools::make_device_view(c_ds);
Expand All @@ -296,7 +296,7 @@ class tridiagonal_solve_stencil {
const unsigned int nby = (ny + 1 - 1) / 1;
const unsigned int nbz = 1;
dim3 blocks(nbx, nby, nbz);
tridiagonal_solve_stencil_stencil49_ms106_kernel<<<blocks, threads>>>(nx,ny,nz,c_ds.strides()[1],c_ds.strides()[2],(c.data()+c_ds.get_storage_info_ptr()->index(c.begin<0>(), c.begin<1>(),0 )),(d.data()+d_ds.get_storage_info_ptr()->index(d.begin<0>(), d.begin<1>(),0 )));
tridiagonal_solve_stencil_stencil49_ms109_kernel<<<blocks, threads>>>(nx,ny,nz,c_ds.strides()[1],c_ds.strides()[2],(c.data()+c_ds.get_storage_info_ptr()->index(c.begin<0>(), c.begin<1>(),0 )),(d.data()+d_ds.get_storage_info_ptr()->index(d.begin<0>(), d.begin<1>(),0 )));
};

// stopping timers
Expand Down
1 change: 1 addition & 0 deletions dawn/test/unit-test/dawn/Optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_executable(${executable}
TestPassSetCaches.cpp
TestPassSetNonTempCaches.cpp
TestPassSetStageLocationType.cpp
TestPassSimplifyStatements.cpp
TestPassStageMerger.cpp
TestPassStageSplitAllStatements.cpp
TestPassStageReordering.cpp
Expand Down
119 changes: 119 additions & 0 deletions dawn/test/unit-test/dawn/Optimizer/TestPassSimplifyStatements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
//===--------------------------------------------------------------------------------*- C++ -*-===//
// _
// | |
// __| | __ ___ ___ ___
// / _` |/ _` \ \ /\ / / '_ |
// | (_| | (_| |\ V V /| | | |
// \__,_|\__,_| \_/\_/ |_| |_| - Compiler Toolchain
//
//
// This file is distributed under the MIT License (MIT).
// See LICENSE.txt for details.
//
//===------------------------------------------------------------------------------------------===//

#include "dawn/AST/ASTStringifier.h"
#include "dawn/IIR/ASTFwd.h"
#include "dawn/IIR/IIR.h"
#include "dawn/IIR/StencilInstantiation.h"
#include "dawn/Optimizer/OptimizerContext.h"
#include "dawn/Optimizer/PassSimplifyStatements.h"
#include "dawn/Serialization/IIRSerializer.h"
#include "dawn/Unittest/ASTConstructionAliases.h"
#include "dawn/Unittest/UnittestUtils.h"

#include <fstream>
#include <gtest/gtest.h>
#include <memory>

using namespace dawn;
using namespace astgen;

namespace {

std::shared_ptr<iir::StencilInstantiation> initializeInstantiation(const std::string& filename) {
UIDGenerator::getInstance()->reset();
auto instantiation = IIRSerializer::deserialize(filename);
DiagnosticsEngine diag;
OptimizerContext context(diag, {}, {{instantiation->getName(), instantiation}});

PassSimplifyStatements pass(context);
pass.run(instantiation);
EXPECT_TRUE(!diag.hasErrors());

return instantiation;
}

TEST(TestPassSimplifyStatements, CompoundStatement) {
// b += a;
// d -= c;
auto instantiation = initializeInstantiation("input/test_simplify_statements_compound_statement.iir");
auto const& firstStmt = getNthStmt(getFirstDoMethod(instantiation), 0);
ASSERT_TRUE(firstStmt->equals(expr(assign(field("b"), binop(field("b"), "+", field("a")))).get(),
/*compareData = */ false));
auto const& secondStmt = getNthStmt(getFirstDoMethod(instantiation), 1);
ASSERT_TRUE(secondStmt->equals(expr(assign(field("d"), binop(field("d"), "-", field("c")))).get(),
/*compareData = */ false));
}

TEST(TestPassSimplifyStatements, IncrementDecrement) {
// int b = d;
// int c = d;
// --b;
// ++c;
// a = c + b;
auto instantiation = initializeInstantiation("input/test_simplify_statements_increment_decrement.iir");
ASSERT_EQ(5, getFirstDoMethod(instantiation).getAST().getStatements().size());
auto const& firstStmt = getNthStmt(getFirstDoMethod(instantiation), 2);
ASSERT_TRUE(firstStmt->equals(expr(assign(var("b"), binop(var("b"), "-", lit(1)))).get(),
/*compareData = */ false));
auto const& secondStmt = getNthStmt(getFirstDoMethod(instantiation), 3);
ASSERT_TRUE(secondStmt->equals(expr(assign(var("c"), binop(var("c"), "+", lit(1)))).get(),
/*compareData = */ false));
}

TEST(TestPassSimplifyStatements, IncrementNested) {
// int b = c;
// a = (++b) + 1;
auto instantiation = initializeInstantiation("input/test_simplify_statements_increment_nested.iir");
ASSERT_EQ(3, getFirstDoMethod(instantiation).getAST().getStatements().size());
auto const& firstStmt = getNthStmt(getFirstDoMethod(instantiation), 1);
ASSERT_TRUE(firstStmt->equals(expr(assign(var("b"), binop(var("b"), "+", lit(1)))).get(),
/*compareData = */ false));
auto const& secondStmt = getNthStmt(getFirstDoMethod(instantiation), 2);
ASSERT_TRUE(secondStmt->equals(expr(assign(field("a"), binop(var("b"), "+", lit(1)))).get(),
/*compareData = */ false));
}

TEST(TestPassSimplifyStatements, MixMultipleNested) {
// int b = d;
// int c = d;
// a += ++b + (1 + --c);
// a *= ++c * --b;
auto instantiation = initializeInstantiation("input/test_simplify_statements_mix_multiple_nested.iir");
ASSERT_EQ(8, getFirstDoMethod(instantiation).getAST().getStatements().size());
auto const& firstStmt = getNthStmt(getFirstDoMethod(instantiation), 2);
ASSERT_TRUE(firstStmt->equals(expr(assign(var("b"), binop(var("b"), "+", lit(1)))).get(),
/*compareData = */ false));
auto const& secondStmt = getNthStmt(getFirstDoMethod(instantiation), 3);
ASSERT_TRUE(secondStmt->equals(expr(assign(var("c"), binop(var("c"), "-", lit(1)))).get(),
/*compareData = */ false));
auto const& thirdStmt = getNthStmt(getFirstDoMethod(instantiation), 4);
ASSERT_TRUE(thirdStmt->equals(
expr(assign(field("a"),
binop(field("a"), "+", binop(var("b"), "+", binop(lit(1), "+", var("c"))))))
.get(),
/*compareData = */ false));
auto const& fourthStmt = getNthStmt(getFirstDoMethod(instantiation), 5);
ASSERT_TRUE(fourthStmt->equals(expr(assign(var("c"), binop(var("c"), "+", lit(1)))).get(),
/*compareData = */ false));
auto const& fifthStmt = getNthStmt(getFirstDoMethod(instantiation), 6);
ASSERT_TRUE(fifthStmt->equals(expr(assign(var("b"), binop(var("b"), "-", lit(1)))).get(),
/*compareData = */ false));
auto const& sixthStmt = getNthStmt(getFirstDoMethod(instantiation), 7);
ASSERT_TRUE(sixthStmt->equals(
expr(assign(field("a"), binop(field("a"), "*", binop(var("c"), "*", var("b"))))).get(),
/*compareData = */ false));
}

} // namespace
Loading