Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 54 additions & 4 deletions isl/interface/scala.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,20 @@ std::string isl_class_to_scala(const std::string &s) {
}
return to_camel_case(s.substr(4));
}

std::string scala_generator::prototype_to_scala(const FunctionProtoType *ft) {
std::string name = "(";
for(int i = 0; i < ft->getNumParams(); i++) {
if(i != 0)
name += ", ";
const auto &arg = ft->getParamType(i);
name += isl_type_to_scala(arg);
}
name += ") => ";
name += isl_type_to_scala(ft->getReturnType());
return name;
}

std::string scala_generator::prototype_to_jni(const FunctionProtoType *ft) {
std::string name = isl_type_to_scala(ft->getReturnType());
for (const auto &arg : ft->param_types()) {
name += isl_type_to_scala(arg);
Expand Down Expand Up @@ -240,7 +252,9 @@ std::string scala_generator::isl_type_to_scala(const QualType &type, const bool
return "AbstractReference[" + isl_type_to_scala(t) + "]";
if(t->isFunctionProtoType()) {
auto ft = t->getAs<FunctionProtoType>();
callback_types[prototype_to_scala(ft)] = const_cast<FunctionProtoType*>(ft);
callback_types[prototype_to_jni(ft)] = const_cast<FunctionProtoType*>(ft);
if (for_jni)
return prototype_to_jni(ft);
return prototype_to_scala(ft);
}
if (for_jni)
Expand Down Expand Up @@ -488,9 +502,9 @@ void scala_generator::generate()
}

for(auto const & c : callback_types) {
os << "trait " << c.first << ":"<< std::endl;
os << "private[isl] trait " << c.first << ":"<< std::endl;
os << " @Delegate" << std::endl;
os << " def apply";
os << " def JNIApply";
unsigned i = 0;
os << "(";

Expand All @@ -501,6 +515,42 @@ void scala_generator::generate()
}
os << ")";
os << ": " << isl_type_to_scala(c.second->getReturnType(), true) << std::endl;

os << "object " << c.first << ":" << std::endl;
os << " private[isl] given Conversion[" << c.first << ", " << prototype_to_scala(c.second) << "] = ";
os << "f => (";
for (int i = 0; i < c.second->getNumParams(); i++) {
auto p = c.second->getParamType(i);
if (i != 0)
os << ", ";
os << "arg" << (i + 1) << ": " << isl_type_to_scala(p);
}
os << ") => f.JNIApply(";
for (int i = 0; i < c.second->getNumParams(); i++) {
if (i != 0)
os << ", ";
os << "arg" << (i + 1);
}
os << ")" << std::endl;os << " private[isl] given Conversion[" << prototype_to_scala(c.second) << ", " << c.first << "] = ";
os << "f => (";
for (int i = 0; i < c.second->getNumParams(); i++) {
auto p = c.second->getParamType(i);
if (i != 0)
os << ", ";
os << "arg" << (i + 1) << ": " << isl_type_to_scala(p, true);
}
os << ") => f(";
for (int i = 0; i < c.second->getNumParams(); i++) {
auto p = c.second->getParamType(i);
if (i != 0)
os << ", ";
auto scala_type = isl_type_to_scala(p, true);
if(scala_type == "Pointer" || scala_type == "Int")
os << "arg" << (i + 1);
else
os << "new " << scala_type << "(arg" << (i + 1) << ")";
}
os << ")" << std::endl;
}

os.flush();
Expand Down
1 change: 1 addition & 0 deletions isl/interface/scala.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class scala_generator : public generator {
void printLibraryCall(std::ostream &os, const FunctionDecl *f, bool as_method);
void printId(std::ostream &os, const std::string& id);
std::string prototype_to_scala(const FunctionProtoType *ft);
std::string prototype_to_jni(const FunctionProtoType *ft);
};

#endif /* ISL_INTERFACE_SCALA_H */
6 changes: 5 additions & 1 deletion src/main/scala-3/ISL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@ object Example:
given ctx: Ctx = Ctx()
val str = "{ [i, j] : 0 <= i <= 10 and 0 <= j <= 10 }"
val basicSet = BasicSet(str)
println(basicSet.samplePoint())
basicSet.toSet().foreachPoint((p, u) => {
println(p)
0
},
null)