Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 39 additions & 39 deletions roofit/codegen/src/CodegenImpl.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@

#include <TInterpreter.h>

namespace RooFit {
namespace Experimental {
namespace RooFit::Experimental {

namespace {

Expand Down Expand Up @@ -103,7 +102,7 @@ void rooHistTranslateImpl(RooAbsArg const &arg, CodegenContext &ctx, int intOrde
}

std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, RooArgList const &funcList,
RooArgList const &coefList, bool normalize)
RooArgList const &coefList, bool normalize, bool forceScopeIndependent)
{
bool noLastCoeff = funcList.size() != coefList.size();

Expand All @@ -113,7 +112,12 @@ std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, R

std::string sum = ctx.getTmpVarName();
std::string coeffSum = ctx.getTmpVarName();
ctx.addToCodeBody(&arg, "double " + sum + " = 0;\ndouble " + coeffSum + "= 0;\n");
std::string code1 = "double " + sum + " = 0;\ndouble " + coeffSum + "= 0;\n";

if (forceScopeIndependent)
ctx.addToCodeBody(code1, true);
else
ctx.addToCodeBody(&arg, code1);

std::string iterator = "i_" + ctx.getTmpVarName();
std::string subscriptExpr = "[" + iterator + "]";
Expand All @@ -128,7 +132,10 @@ std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, R
} else if (normalize) {
code += sum + " /= " + coeffSum + ";\n";
}
ctx.addToCodeBody(&arg, code);
if (forceScopeIndependent)
ctx.addToCodeBody(code, true);
else
ctx.addToCodeBody(&arg, code);

return sum;
}
Expand Down Expand Up @@ -240,7 +247,7 @@ void codegenImpl(RooAbsArg &arg, CodegenContext &ctx)

void codegenImpl(RooAddPdf &arg, CodegenContext &ctx)
{
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.pdfList(), arg.coefList(), true));
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.pdfList(), arg.coefList(), true, false));
}

void codegenImpl(RooMultiVarGaussian &arg, CodegenContext &ctx)
Expand All @@ -261,30 +268,25 @@ void codegenImpl(RooMultiPdf &arg, CodegenContext &ctx)
// indices MathFunc call becomes more efficient.
if (numPdfs > 2) {
ctx.addResult(&arg, ctx.buildCall(mathFunc("multipdf"), arg.indexCategory(), arg.getPdfList()));
return;
}
// Ternary nested expression
std::string indexExpr = ctx.getResult(arg.indexCategory());

std::cout << "MathFunc call used\n";

} else {

// Ternary nested expression
std::string indexExpr = ctx.getResult(arg.indexCategory());

// int numPdfs = arg.getNumPdfs();
std::string expr;
// int numPdfs = arg.getNumPdfs();
std::string expr;

for (int i = 0; i < numPdfs; ++i) {
RooAbsPdf *pdf = arg.getPdf(i);
std::string pdfExpr = ctx.getResult(*pdf);
for (int i = 0; i < numPdfs; ++i) {
RooAbsPdf *pdf = arg.getPdf(i);
std::string pdfExpr = ctx.getResult(*pdf);

expr += "(" + indexExpr + " == " + std::to_string(i) + " ? (" + pdfExpr + ") : ";
}
expr += "(" + indexExpr + " == " + std::to_string(i) + " ? (" + pdfExpr + ") : ";
}

expr += "0.0";
expr += std::string(numPdfs, ')'); // Close all ternary operators
expr += "0.0";
expr += std::string(numPdfs, ')'); // Close all ternary operators

ctx.addResult(&arg, expr);
std::cout << "Ternary expression call used \n";
}
ctx.addResult(&arg, expr);
}

// RooCategory index added.
Expand All @@ -294,7 +296,7 @@ void codegenImpl(RooCategory &arg, CodegenContext &ctx)
if (idx < 0) {

idx = 1;
ctx.addVecObs(arg.GetName(), idx);
ctx.addVecObs(arg.GetName(), idx, 1);
}

std::string result = std::to_string(arg.getCurrentIndex());
Expand All @@ -305,6 +307,7 @@ void codegenImpl(RooAddition &arg, CodegenContext &ctx)
{
if (arg.list().empty()) {
ctx.addResult(&arg, "0.0");
return;
}
std::string result;
if (arg.list().size() > 1)
Expand Down Expand Up @@ -469,7 +472,6 @@ void codegenImpl(RooFit::Detail::RooNLLVarNew &arg, CodegenContext &ctx)

std::string weightSumName = RooFit::Detail::makeValidVarName(arg.GetName()) + "WeightSum";
std::string resName = RooFit::Detail::makeValidVarName(arg.GetName()) + "Result";
ctx.addResult(&arg, resName);
ctx.addToGlobalScope("double " + weightSumName + " = 0.0;\n");
ctx.addToGlobalScope("double " + resName + " = 0.0;\n");

Expand All @@ -496,6 +498,8 @@ void codegenImpl(RooFit::Detail::RooNLLVarNew &arg, CodegenContext &ctx)
std::string expected = ctx.getResult(*arg.expectedEvents());
ctx.addToCodeBody(resName + " += " + expected + " - " + weightSumName + " * std::log(" + expected + ");\n");
}

ctx.addResult(&arg, resName);
}

void codegenImpl(RooFit::Detail::RooNormalizedPdf &arg, CodegenContext &ctx)
Expand Down Expand Up @@ -609,17 +613,17 @@ void codegenImpl(RooRealIntegral &arg, CodegenContext &ctx)
auto &intVar = static_cast<RooAbsRealLValue &>(*arg.numIntRealVars()[0]);

std::string obsName = ctx.getTmpVarName();
std::string oldIntVarResult = ctx.getResult(intVar);
ctx.addResult(&intVar, "obs[0]");

auto oldVecObsInfo = ctx._vecObsIndices[intVar.namePtr()];
ctx.addVecObs(intVar.GetName(), 0, 1);
std::string funcName = ctx.buildFunction(arg.integrand(), {});
ctx._vecObsIndices[intVar.namePtr()] = oldVecObsInfo;

std::stringstream ss;

ss << "double " << obsName << "[1];\n";

std::string resName = RooFit::Detail::makeValidVarName(arg.GetName()) + "Result";
ctx.addResult(&arg, resName);
ctx.addToGlobalScope("double " + resName + " = 0.0;\n");

// TODO: once Clad has support for higher-order functions (follow also the
Expand All @@ -640,24 +644,21 @@ void codegenImpl(RooRealIntegral &arg, CodegenContext &ctx)

ctx.addToGlobalScope(ss.str());

ctx.addResult(&intVar, oldIntVarResult);
ctx.addResult(&arg, resName);
}

void codegenImpl(RooRealSumFunc &arg, CodegenContext &ctx)
{
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.funcList(), arg.coefList(), false));
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.funcList(), arg.coefList(), false, false));
}

void codegenImpl(RooRealSumPdf &arg, CodegenContext &ctx)
{
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.funcList(), arg.coefList(), false));
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.funcList(), arg.coefList(), false, false));
}

void codegenImpl(RooRealVar &arg, CodegenContext &ctx)
{
if (!arg.isConstant()) {
ctx.addResult(&arg, arg.GetName());
}
ctx.addResult(&arg, doubleToString(arg.getVal()));
}

Expand Down Expand Up @@ -898,7 +899,7 @@ std::string codegenIntegralImpl(RooPolynomial &arg, int, const char *rangeName,
std::string codegenIntegralImpl(RooRealSumPdf &arg, int code, const char *rangeName, CodegenContext &ctx)
{
// Re-use translate, since integration is linear.
return realSumPdfTranslateImpl(ctx, arg, arg.funcIntListFromCache(code, rangeName), arg.coefList(), false);
return realSumPdfTranslateImpl(ctx, arg, arg.funcIntListFromCache(code, rangeName), arg.coefList(), false, true);
}

std::string codegenIntegralImpl(RooUniform &arg, int code, const char *rangeName, CodegenContext &)
Expand All @@ -908,5 +909,4 @@ std::string codegenIntegralImpl(RooUniform &arg, int code, const char *rangeName
return doubleToString(arg.analyticalIntegral(code, rangeName));
}

} // namespace Experimental
} // namespace RooFit
} // namespace RooFit::Experimental
29 changes: 13 additions & 16 deletions roofit/roofitcore/inc/RooFit/CodegenContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
template <class T>
class RooTemplateProxy;

namespace RooFit {
namespace Experimental {
namespace RooFit::Experimental {

template <int P>
struct Prio {
Expand All @@ -46,12 +45,11 @@ using PrioLowest = Prio<10>;
class CodegenContext {
public:
void addResult(RooAbsArg const *key, std::string const &value);
void addResult(const char *key, std::string const &value);

std::string const &getResult(RooAbsArg const &arg);
std::string getResult(RooAbsArg const &arg);

template <class T>
std::string const &getResult(RooTemplateProxy<T> const &key)
std::string getResult(RooTemplateProxy<T> const &key)
{
return getResult(key.arg());
}
Expand All @@ -69,7 +67,8 @@ class CodegenContext {
}

void addToGlobalScope(std::string const &str);
void addVecObs(const char *key, int idx);
void addVecObs(const char *key, int idx, std::size_t size);
void addParam(const RooAbsArg *key, int idx);
int observableIndexOf(const RooAbsArg &arg) const;

void addToCodeBody(RooAbsArg const *klass, std::string const &in);
Expand Down Expand Up @@ -135,6 +134,13 @@ class CodegenContext {
};
ScopeRAII OutputScopeRangeComment(RooAbsArg const *arg) { return {arg, *this}; }

/// @brief Map of node names to their result strings.
std::unordered_map<const TNamed *, std::size_t> _nodeNames;
std::size_t _nWksp = 0;
std::unordered_map<const RooAbsArg *, int> _paramIndices;
/// @brief A map to keep track of the observable indices if they are non scalar.
std::unordered_map<const TNamed *, std::pair<int, std::size_t>> _vecObsIndices;

private:
void pushScope();
void popScope();
Expand All @@ -145,8 +151,6 @@ class CodegenContext {

void endLoop(LoopScope const &scope);

void addResult(TNamed const *key, std::string const &value);

template <class T, typename std::enable_if<std::is_floating_point<T>{}, bool>::type = true>
std::string buildArg(T x)
{
Expand Down Expand Up @@ -191,10 +195,6 @@ class CodegenContext {
template <class T>
std::string typeName() const;

/// @brief Map of node names to their result strings.
std::unordered_map<const TNamed *, std::string> _nodeNames;
/// @brief A map to keep track of the observable indices if they are non scalar.
std::unordered_map<const TNamed *, int> _vecObsIndices;
/// @brief Map of node output sizes.
std::map<RooFit::Detail::DataKey, std::size_t> _nodeOutputSizes;
/// @brief The code layered by lexical scopes used as a stack.
Expand All @@ -203,8 +203,6 @@ class CodegenContext {
unsigned _indent = 0;
/// @brief Index to get unique names for temporary variables.
mutable int _tmpVarIdx = 0;
/// @brief A map to keep track of list names as assigned by addResult.
std::unordered_map<RooFit::UniqueId<RooAbsCollection>::Value_t, std::string> _listNames;
std::vector<double> _xlArr;
std::vector<std::string> _collectedFunctions;
};
Expand Down Expand Up @@ -242,7 +240,6 @@ void declareDispatcherCode(std::string const &funcName);

void codegen(RooAbsArg &arg, CodegenContext &ctx);

} // namespace Experimental
} // namespace RooFit
} // namespace RooFit::Experimental

#endif
17 changes: 17 additions & 0 deletions roofit/roofitcore/inc/RooFit/Detail/MathFuncs.h
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,23 @@ double stepFunctionIntegral(double xmin, double xmax, std::size_t nBins, DoubleA

} // namespace RooFit::Detail::MathFuncs

inline void fillFromWorkspace(double *out, std::size_t n, double const *wksp, double const *idx)
{
for (std::size_t i = 0; i < n; ++i) {
out[i] += wksp[static_cast<int>(idx[i])];
}
}

namespace clad::custom_derivatives {

inline void fillFromWorkspace_pullback(double *, std::size_t n, double const *, double const *idx, double *d_out,
std::size_t *, double *d_wksp, double *)
{
for (std::size_t i = 0; i < n; ++i) {
d_wksp[static_cast<int>(idx[i])] += d_out[i];
}
}

namespace RooFit::Detail::MathFuncs {

// Clad can't generate the pullback for binNumber because of the
Expand All @@ -826,6 +842,7 @@ void binNumber_pullback(Types...)
}

} // namespace RooFit::Detail::MathFuncs

} // namespace clad::custom_derivatives

#endif
12 changes: 2 additions & 10 deletions roofit/roofitcore/src/RooEvaluatorWrapper.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -239,21 +239,13 @@ RooFuncWrapper::RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimul
// First update the result variable of params in the compute graph to in[<position>].
int idx = 0;
for (RooAbsArg *param : _params) {
ctx.addResult(param, "params[" + std::to_string(idx) + "]");
ctx.addParam(param, idx);
idx++;
}

for (auto const &item : _obsInfos) {
const char *obsName = item.first->GetName();
// If the observable is scalar, set name to the start idx. else, store
// the start idx and later set the the name to obs[start_idx + curr_idx],
// here curr_idx is defined by a loop producing parent node.
if (item.second.size == 1) {
ctx.addResult(obsName, "obs[" + std::to_string(item.second.idx) + "]");
} else {
ctx.addResult(obsName, "obs");
ctx.addVecObs(obsName, item.second.idx);
}
ctx.addVecObs(obsName, item.second.idx, item.second.size);
}

gInterpreter->Declare("#pragma cling optimize(2)");
Expand Down
Loading
Loading