Skip to content

Commit

Permalink
[df] Allow untyped reading of TTree values
Browse files Browse the repository at this point in the history
  • Loading branch information
vepadulano committed Feb 21, 2025
1 parent d2e7874 commit 14e797e
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 36 deletions.
83 changes: 47 additions & 36 deletions tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "RColumnReaderBase.hxx"
#include <ROOT/RVec.hxx>
#include "ROOT/RDF/Utils.hxx"
#include <Rtypes.h> // Long64_t, R__CLING_PTRCHECK
#include <TTreeReader.h>
#include <TTreeReaderValue.h>
Expand All @@ -22,6 +23,7 @@
#include <array>
#include <memory>
#include <string>
#include <cstddef>

namespace ROOT {
namespace Internal {
Expand All @@ -30,13 +32,13 @@ namespace RDF {
/// RTreeColumnReader specialization for TTree values read via TTreeReaderValues
template <typename T>
class R__CLING_PTRCHECK(off) RTreeColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase {
std::unique_ptr<TTreeReaderValue<T>> fTreeValue;
std::unique_ptr<TTreeReaderUntypedValue> fTreeValue;

void *GetImpl(Long64_t) final { return fTreeValue->Get(); }
public:
/// Construct the RTreeColumnReader. Actual initialization is performed lazily by the Init method.
RTreeColumnReader(TTreeReader &r, const std::string &colName)
: fTreeValue(std::make_unique<TTreeReaderValue<T>>(r, colName.c_str()))
: fTreeValue(std::make_unique<TTreeReaderUntypedValue>(r, colName.c_str(), ROOT::Internal::RDF::TypeID2TypeName(typeid(T))))
{
}
};
Expand All @@ -59,16 +61,21 @@ public:
/// TTreeReaderArrays are used whenever the RDF column type is RVec<T>.
template <typename T>
class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<T>> final : public ROOT::Detail::RDF::RColumnReaderBase {
std::unique_ptr<TTreeReaderArray<T>> fTreeArray;
std::unique_ptr<TTreeReaderUntypedArray> fTreeArray;

using Byte_t = std::byte;

/// We return a reference to this RVec to clients, to guarantee a stable address and contiguous memory layout.
RVec<T> fRVec;
RVec<Byte_t> fRVec;

Long64_t fLastEntry = -1;

/// Whether we already printed a warning about performing a copy of the TTreeReaderArray contents
bool fCopyWarningPrinted = false;

/// The size of the collection value type.
std::size_t fValueSize{};

void *GetImpl(Long64_t entry) final
{
if (entry == fLastEntry)
Expand All @@ -86,11 +93,10 @@ class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<T>> final : public ROOT::Det
// trigger loading of the contents of the TTreeReaderArray
// the address of the first element in the reader array is not necessarily equal to
// the address returned by the GetAddress method
auto readerArrayAddr = &readerArray.At(0);
RVec<T> rvec(readerArrayAddr, readerArraySize);
RVec<Byte_t> rvec(readerArray.At(0), readerArraySize);
swap(fRVec, rvec);
} else {
RVec<T> emptyVec{};
RVec<Byte_t> emptyVec{};
swap(fRVec, emptyVec);
}
} else {
Expand All @@ -107,10 +113,21 @@ class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<T>> final : public ROOT::Det
(void)fCopyWarningPrinted;
#endif
if (readerArraySize > 0) {
RVec<T> rvec(readerArray.begin(), readerArray.end());
swap(fRVec, rvec);
// Caching the value type size since GetValueSize might be expensive.
if (fValueSize == 0)
fValueSize = readerArray.GetValueSize();
assert(fValueSize > 0 && "Could not retrieve size of collection value type.");
// Array is not contiguous, make a full copy of it.
fRVec = RVec<Byte_t>();
fRVec.reserve(readerArraySize * fValueSize);
for (std::size_t i{0}; i < readerArraySize; i++)
{
auto val = readerArray.At(i);
std::copy(val, val + fValueSize, std::back_inserter(fRVec));
}
fRVec.resize(readerArraySize);
} else {
RVec<T> emptyVec{};
RVec<Byte_t> emptyVec{};
swap(fRVec, emptyVec);
}
}
Expand All @@ -120,7 +137,7 @@ class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<T>> final : public ROOT::Det

public:
RTreeColumnReader(TTreeReader &r, const std::string &colName)
: fTreeArray(std::make_unique<TTreeReaderArray<T>>(r, colName.c_str()))
: fTreeArray(std::make_unique<TTreeReaderUntypedArray>(r, colName, ROOT::Internal::RDF::TypeID2TypeName(typeid(T))))
{
}
};
Expand All @@ -131,10 +148,12 @@ public:
template <>
class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<bool>> final : public ROOT::Detail::RDF::RColumnReaderBase {

std::unique_ptr<TTreeReaderArray<bool>> fTreeArray;
using Byte_t = std::byte;

std::unique_ptr<TTreeReaderUntypedArray> fTreeArray;

/// We return a reference to this RVec to clients, to guarantee a stable address and contiguous memory layout
RVec<bool> fRVec;
RVec<Byte_t> fRVec;

// We always copy the contents of TTreeReaderArray<bool> into an RVec<bool> (never take a view into the memory
// buffer) because the underlying memory buffer might be the one of a std::vector<bool>, which is not a contiguous
Expand All @@ -146,19 +165,25 @@ class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<bool>> final : public ROOT::
auto &readerArray = *fTreeArray;
const auto readerArraySize = readerArray.GetSize();
if (readerArraySize > 0) {
// always perform a copy
RVec<bool> rvec(readerArray.begin(), readerArray.end());
swap(fRVec, rvec);
// Always perform a copy
fRVec = RVec<Byte_t>();
fRVec.reserve(readerArraySize * sizeof(bool));
for (std::size_t i{0}; i < readerArraySize; i++)
{
auto val = readerArray.At(i);
std::copy(val, val + sizeof(bool), std::back_inserter(fRVec));
}
fRVec.resize(readerArraySize);
} else {
RVec<bool> emptyVec{};
RVec<Byte_t> emptyVec{};
swap(fRVec, emptyVec);
}
return &fRVec;
}

public:
RTreeColumnReader(TTreeReader &r, const std::string &colName)
: fTreeArray(std::make_unique<TTreeReaderArray<bool>>(r, colName.c_str()))
: fTreeArray(std::make_unique<TTreeReaderUntypedArray>(r, colName.c_str(), ROOT::Internal::RDF::TypeID2TypeName(typeid(bool))))
{
}
};
Expand All @@ -168,32 +193,18 @@ public:
/// This specialization is used when the requested type for reading is std::array
template <typename T, std::size_t N>
class R__CLING_PTRCHECK(off) RTreeColumnReader<std::array<T, N>> final : public ROOT::Detail::RDF::RColumnReaderBase {
std::unique_ptr<TTreeReaderArray<T>> fTreeArray;

/// We return a reference to this RVec to clients, to guarantee a stable address and contiguous memory layout
RVec<T> fArray;
std::unique_ptr<TTreeReaderUntypedArray> fTreeArray;

Long64_t fLastEntry = -1;

void *GetImpl(Long64_t entry) final
void *GetImpl(Long64_t) final
{
if (entry == fLastEntry)
return fArray.data();

// This is a non-owning view on the contents of the TTreeReaderArray
RVec<T> view{&fTreeArray->At(0), fTreeArray->GetSize()};
swap(fArray, view);

fLastEntry = entry;
// The data member of this class is an RVec, to avoid an extra copy
// but we need to return the array buffer as the reader expects
// a std::array
return fArray.data();
return fTreeArray->At(0);
}

public:
RTreeColumnReader(TTreeReader &r, const std::string &colName)
: fTreeArray(std::make_unique<TTreeReaderArray<T>>(r, colName.c_str()))
: fTreeArray(std::make_unique<TTreeReaderUntypedArray>(r, colName.c_str(), ROOT::Internal::RDF::TypeID2TypeName(typeid(T))))
{
}
};
Expand Down
2 changes: 2 additions & 0 deletions tree/treeplayer/inc/TBranchProxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,8 @@ namespace Detail {
Int_t GetOffset() { return fOffset; }

bool GetSuppressErrorsForMissingBranch() const { return fSuppressMissingBranchError; }

Int_t GetStreamerElementSize() const;
};
} // namespace Detail

Expand Down
21 changes: 21 additions & 0 deletions tree/treeplayer/inc/TTreeReaderArray.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// @(#)root/tree:$Id$
// Author: Axel Naumann, 2010-08-02
// Author: Vincenzo Eduardo Padulano CERN 02/2025

/*************************************************************************
* Copyright (C) 1995-2013, Rene Brun and Fons Rademakers. *
Expand All @@ -15,6 +16,7 @@
#include "TTreeReaderValue.h"
#include "TTreeReaderUtils.h"
#include <type_traits>
#include <cstddef>

namespace ROOT {
namespace Internal {
Expand All @@ -37,6 +39,9 @@ class TTreeReaderArrayBase : public TTreeReaderValueBase {

bool IsContiguous() const { return fImpl->IsContiguous(GetProxy()); }

/// Returns the `sizeof` of the collection value type. Returns 0 in case the value size could not be retrieved.
std::size_t GetValueSize() const { return fImpl ? fImpl->GetValueSize(GetProxy()): 0; }

protected:
void *UntypedAt(std::size_t idx) const { return fImpl->At(GetProxy(), idx); }
void CreateProxy() override;
Expand All @@ -51,6 +56,22 @@ class TTreeReaderArrayBase : public TTreeReaderValueBase {
// ClassDefOverride(TTreeReaderArrayBase, 0);//Accessor to member of an object stored in a collection
};

class R__CLING_PTRCHECK(off) TTreeReaderUntypedArray final : public TTreeReaderArrayBase {
std::string fArrayElementTypeName;

public:
TTreeReaderUntypedArray(TTreeReader &tr, std::string_view branchName, std::string_view innerTypeName)
: TTreeReaderArrayBase(&tr, branchName.data(), TDictionary::GetDictionary(innerTypeName.data())),
fArrayElementTypeName(innerTypeName)
{
}

std::byte *At(std::size_t idx) const { return reinterpret_cast<std::byte *>(UntypedAt(idx)); }

protected:
const char *GetDerivedTypeName() const final { return fArrayElementTypeName.c_str(); }
};

} // namespace Internal
} // namespace ROOT

Expand Down
1 change: 1 addition & 0 deletions tree/treeplayer/inc/TTreeReaderUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ namespace Internal {
virtual size_t GetSize(Detail::TBranchProxy*) = 0;
virtual void* At(Detail::TBranchProxy*, size_t /*idx*/) = 0;
virtual bool IsContiguous(Detail::TBranchProxy *) = 0;
virtual std::size_t GetValueSize(Detail::TBranchProxy *) = 0;
};

}
Expand Down
24 changes: 24 additions & 0 deletions tree/treeplayer/inc/TTreeReaderValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,30 @@ class R__CLING_PTRCHECK(off) TTreeReaderOpaqueValue final : public ROOT::Interna
const char *GetDerivedTypeName() const { return ""; }
};

class R__CLING_PTRCHECK(off) TTreeReaderUntypedValue final : public TTreeReaderValueBase {
std::string fElementTypeName;

public:
TTreeReaderUntypedValue(TTreeReader &tr, std::string_view branchName, std::string_view typeName)
: TTreeReaderValueBase(&tr, branchName.data(), TDictionary::GetDictionary(typeName.data())),
fElementTypeName(typeName)
{
}

void *Get()
{
if (!fProxy) {
ErrorAboutMissingProxyIfNeeded();
return nullptr;
}
void *address = GetAddress(); // Needed to figure out if it's a pointer
return fProxy->IsaPointer() ? *(void **)address : (void *)address;
}

protected:
const char *GetDerivedTypeName() const final { return fElementTypeName.c_str(); }
};

} // namespace Internal
} // namespace ROOT

Expand Down
5 changes: 5 additions & 0 deletions tree/treeplayer/src/TBranchProxy.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -583,3 +583,8 @@ bool ROOT::Detail::TBranchProxy::Setup()
return false;
}
}

Int_t ROOT::Detail::TBranchProxy::GetStreamerElementSize() const
{
return fElement ? fElement->GetSize() : 0;
}
50 changes: 50 additions & 0 deletions tree/treeplayer/src/TTreeReaderArray.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include <memory>
#include <optional>
#include <iostream>

// pin vtable
ROOT::Internal::TVirtualCollectionReader::~TVirtualCollectionReader() {}
Expand Down Expand Up @@ -72,6 +73,11 @@ class TClonesReader : public TVirtualCollectionReader {
}

bool IsContiguous(ROOT::Detail::TBranchProxy *) override { return false; }

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *proxy) override {
auto *ca = GetCA(proxy);
return ca ? ca->GetClass()->Size() : 0;
}
};

bool IsCPContiguous(const TVirtualCollectionProxy &cp)
Expand All @@ -86,6 +92,14 @@ bool IsCPContiguous(const TVirtualCollectionProxy &cp)
}
}

UInt_t GetCPValueSize(const TVirtualCollectionProxy &cp)
{
// This works only if the collection proxy value type is a fundamental type
auto &&eDataType = cp.GetType();
auto *tDataType = TDataType::GetDataType(eDataType);
return tDataType ? tDataType->Size() : 0;
}

// Reader interface for STL
class TSTLReader final : public TVirtualCollectionReader {
public:
Expand Down Expand Up @@ -131,6 +145,12 @@ class TSTLReader final : public TVirtualCollectionReader {
auto cp = GetCP(proxy);
return IsCPContiguous(*cp);
}

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *proxy) override
{
auto cp = GetCP(proxy);
return GetCPValueSize(*cp);
}
};

class TCollectionLessSTLReader final : public TVirtualCollectionReader {
Expand Down Expand Up @@ -190,6 +210,12 @@ class TCollectionLessSTLReader final : public TVirtualCollectionReader {
auto cp = GetCP(proxy);
return IsCPContiguous(*cp);
}

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *proxy) override
{
auto cp = GetCP(proxy);
return GetCPValueSize(*cp);
}
};

// Reader interface for leaf list
Expand Down Expand Up @@ -243,6 +269,12 @@ class TObjectArrayReader : public TVirtualCollectionReader {
void SetBasicTypeSize(Int_t size) { fBasicTypeSize = size; }

bool IsContiguous(ROOT::Detail::TBranchProxy *) override { return true; }

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *proxy) override
{
auto cp = GetCP(proxy);
return GetCPValueSize(*cp);
}
};

template <class BASE>
Expand Down Expand Up @@ -387,6 +419,18 @@ class TBasicTypeArrayReader final : public TVirtualCollectionReader {
}

bool IsContiguous(ROOT::Detail::TBranchProxy *) override { return false; }

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *proxy) override
{
if (!proxy->Read()) {
fReadStatus = TTreeReaderValueBase::kReadError;
if (!proxy->GetSuppressErrorsForMissingBranch())
Error("TBasicTypeArrayReader::GetValueSize()", "Read error in TBranchProxy.");
return 0;
}
fReadStatus = TTreeReaderValueBase::kReadSuccess;
return proxy->GetStreamerElementSize();
}
};

class TBasicTypeClonesReader final : public TClonesReader {
Expand Down Expand Up @@ -434,6 +478,12 @@ class TLeafReader : public TVirtualCollectionReader {

bool IsContiguous(ROOT::Detail::TBranchProxy *) override { return true; }

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *) override
{
auto *leaf = fValueReader->GetLeaf();
return leaf ? leaf->GetLenType(): 0;
}

protected:
void ProxyRead() { fValueReader->ProxyRead(); }
};
Expand Down

0 comments on commit 14e797e

Please sign in to comment.