Skip to content

[CAS] Add a new API in ObjectStore to import a CAS tree #10819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: stable/20240723
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/test/CAS/print-compile-job-cache-key.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@
// RUN: cat %t/output-plugin.txt | sed \
// RUN: -e "s/^.*miss for '//" \
// RUN: -e "s/' .*$//" > %t/cache-key-plugin
// RUN: clang-cas-test -print-compile-job-cache-key -cas %t/cas @%t/cache-key-plugin \
// RUN: clang-cas-test -print-compile-job-cache-key -cas %t/cas-plugin @%t/cache-key-plugin \
// RUN: -fcas-plugin-path %llvmshlibdir/libCASPluginTest%pluginext | FileCheck %s
4 changes: 4 additions & 0 deletions llvm/include/llvm/CAS/ObjectStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ class ObjectStore {
/// Validate the whole node tree.
Error validateTree(ObjectRef Ref);

/// Import object from another CAS. This will import the full tree from the
/// other CAS.
Expected<ObjectRef> importObject(ObjectStore &Upstream, ObjectRef Other);

/// Print the ObjectStore internals for debugging purpose.
virtual void print(raw_ostream &) const {}
void dump() const;
Expand Down
87 changes: 87 additions & 0 deletions llvm/lib/CAS/ObjectStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/SmallVectorMemoryBuffer.h"
#include <deque>

using namespace llvm;
using namespace llvm::cas;
Expand Down Expand Up @@ -217,6 +218,92 @@ Error ObjectStore::validateTree(ObjectRef Root) {
return Error::success();
}

Expected<ObjectRef> ObjectStore::importObject(ObjectStore &Upstream,
                                              ObjectRef Other) {
  // Copy the full CAS tree from upstream with depth-first ordering to ensure
  // all the child nodes are available in downstream CAS before inserting
  // current object. This uses a similar algorithm as
  // `OnDiskGraphDB::importFullTree` but doesn't assume the upstream CAS schema
  // so it can be used to import from any other ObjectStore regardless of the
  // CAS schema.
  //
  // \param Upstream the store to copy from.
  // \param Other a reference, valid in \p Upstream, to the root of the tree.
  // \returns the reference of the imported root in this store, or the first
  // error reported by `Upstream.load()` or by `store()`.

  // There is no work to do if importing from self.
  if (this == &Upstream)
    return Other;

  /// Keeps track of the state of visitation for current node and all of its
  /// parents. Upstream Cursor holds information only from upstream CAS.
  struct UpstreamCursor {
    /// Upstream reference of the node being visited.
    ObjectRef Ref;
    /// Loaded upstream handle for \c Ref.
    ObjectHandle Node;
    /// Total number of references the node has.
    size_t RefsCount;
    /// References of the node that have not been visited yet.
    std::deque<ObjectRef> Refs;
  };
  SmallVector<UpstreamCursor, 16> CursorStack;
  /// PrimaryRefStack holds ObjectRefs of the current CAS, for nodes that were
  /// either just stored in the current CAS or already imported earlier.
  SmallVector<ObjectRef, 128> PrimaryRefStack;
  /// A map from upstream ObjectRef to current ObjectRef.
  llvm::DenseMap<ObjectRef, ObjectRef> CreatedObjects;

  // Push a cursor for \p Node, capturing all of its upstream references.
  auto enqueueNode = [&](ObjectRef Ref, ObjectHandle Node) {
    unsigned NumRefs = Upstream.getNumRefs(Node);
    std::deque<ObjectRef> Refs;
    for (unsigned I = 0; I < NumRefs; ++I)
      Refs.push_back(Upstream.readRef(Node, I));

    CursorStack.push_back({Ref, Node, NumRefs, std::move(Refs)});
  };

  auto UpstreamHandle = Upstream.load(Other);
  if (!UpstreamHandle)
    return UpstreamHandle.takeError();
  enqueueNode(Other, *UpstreamHandle);

  while (!CursorStack.empty()) {
    // NOTE: \c Cur refers into CursorStack, so it must not be used after the
    // stack is modified (pop_back below, or push_back inside enqueueNode).
    UpstreamCursor &Cur = CursorStack.back();
    if (Cur.Refs.empty()) {
      // All children are imported. Copy the node data into the primary store.
      // The last \c Cur.RefsCount entries of \p PrimaryRefStack are the
      // references of the current node's children in this CAS.
      assert(PrimaryRefStack.size() >= Cur.RefsCount);
      auto Refs = ArrayRef(PrimaryRefStack)
                      .slice(PrimaryRefStack.size() - Cur.RefsCount);
      auto NewNode = store(Refs, Upstream.getData(Cur.Node));
      if (!NewNode)
        return NewNode.takeError();

      // Replace the children's references on the stack with the new node.
      PrimaryRefStack.truncate(PrimaryRefStack.size() - Cur.RefsCount);
      PrimaryRefStack.push_back(*NewNode);
      CreatedObjects.try_emplace(Cur.Ref, *NewNode);

      // Pop the cursor last: popping destroys the element \c Cur refers to,
      // so every read of \c Cur has to happen before this point.
      CursorStack.pop_back();
      continue;
    }

    // Check if the node exists already.
    auto CurrentID = Cur.Refs.front();
    Cur.Refs.pop_front();
    auto Ref = CreatedObjects.find(CurrentID);
    if (Ref != CreatedObjects.end()) {
      // If exists already, just need to enqueue the primary node.
      PrimaryRefStack.push_back(Ref->second);
      continue;
    }

    // Load child.
    auto PrimaryID = Upstream.load(CurrentID);
    if (LLVM_UNLIKELY(!PrimaryID))
      return PrimaryID.takeError();

    enqueueNode(CurrentID, *PrimaryID);
  }

  // Only the imported root is left once every cursor has been processed.
  assert(PrimaryRefStack.size() == 1);
  return PrimaryRefStack.front();
}

std::unique_ptr<MemoryBuffer>
ObjectProxy::getMemoryBuffer(StringRef Name,
bool RequiresNullTerminator) const {
Expand Down
4 changes: 4 additions & 0 deletions llvm/test/tools/llvm-cas/ingest.test
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ CHECK-ERROR: llvm-cas: get-cas-id: No such file or directory
RUN: llvm-cas --cas %t/cas --ls-node-refs @%t/cas.id 2>&1 | FileCheck %s --check-prefix=CHECK-NODE-REFS
CHECK-NODE-REFS: llvmcas://
CHECK-NODE-REFS: llvmcas://

// Test importing the entire tree from an upstream CAS into a plugin-backed CAS.
RUN: llvm-cas --cas %t/new-cas --fcas-plugin-path %llvmshlibdir/libCASPluginTest%pluginext --upstream-cas %t/cas --import @%t/cas.id > %t/plugin.id
RUN: llvm-cas --cas %t/new-cas --fcas-plugin-path %llvmshlibdir/libCASPluginTest%pluginext --ls-tree-recursive @%t/plugin.id | FileCheck %s
60 changes: 54 additions & 6 deletions llvm/tools/libCASPluginTest/libCASPluginTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
//===----------------------------------------------------------------------===//

#include "llvm-c/CAS/PluginAPI_functions.h"
#include "llvm/CAS/BuiltinCASContext.h"
#include "llvm/CAS/BuiltinObjectHasher.h"
#include "llvm/CAS/CASID.h"
#include "llvm/CAS/UnifiedOnDiskCache.h"
#include "llvm/Support/CBindingWrapping.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/SHA1.h"

using namespace llvm;
using namespace llvm::cas;
using namespace llvm::cas::builtin;
using namespace llvm::cas::ondisk;

static char *copyNewMallocString(StringRef Str) {
Expand Down Expand Up @@ -125,6 +125,54 @@ bool llcas_cas_options_set_option(llcas_cas_options_t c_opts, const char *name,

namespace {

using HasherT = SHA1;
using HashType = decltype(HasherT::hash(std::declval<ArrayRef<uint8_t> &>()));

/// CASContext used by the test plugin. It names the hash ("SHA1"), reports the
/// hash schema identifier, and converts digests to and from their printed
/// "llvmcas://<hex>" form.
class PluginCASContext : public CASContext {
public:
  PluginCASContext() = default;

  /// Name of the hash function used for digests.
  static StringRef getHashName() { return "SHA1"; }

  /// Schema identifier derived from the hash name. The string is built on
  /// first use and kept in a function-local static.
  StringRef getHashSchemaIdentifier() const final {
    static const std::string SchemaID =
        (Twine("llvm.cas.builtin.v2[") + getHashName() + "]").str();
    return SchemaID;
  }

  /// Parse a printed id of the form "llvmcas://<hex digest>".
  ///
  /// \returns the decoded digest, or an `invalid_argument` error when the
  /// scheme prefix is missing, the hex part has the wrong length, or it is
  /// not valid hex. The prefix is consumed from \p Reference first, so the
  /// two later error messages quote the reference without the scheme.
  static Expected<HashType> parseID(StringRef Reference) {
    // All failures share the same error code; only the message differs.
    auto Fail = [](const Twine &Message) -> Error {
      return createStringError(
          std::make_error_code(std::errc::invalid_argument), Message);
    };

    if (!Reference.consume_front("llvmcas://"))
      return Fail("invalid cas-id '" + Reference + "'");

    // Two hex characters encode one digest byte.
    if (Reference.size() != 2 * sizeof(HashType))
      return Fail("wrong size for cas-id hash '" + Reference + "'");

    std::string Bytes;
    if (!tryGetFromHex(Reference, Bytes))
      return Fail("invalid hash in cas-id '" + Reference + "'");

    assert(Bytes.size() == sizeof(HashType));
    HashType Result;
    llvm::copy(Bytes, Result.data());
    return Result;
  }

  /// Write \p Digest to \p OS as "llvmcas://" followed by lower-case hex.
  static void printID(ArrayRef<uint8_t> Digest, raw_ostream &OS) {
    OS << "llvmcas://" << toHex(Digest, /*LowerCase=*/true);
  }

private:
  /// CASContext hook: print \p ID using the plugin's printed-id format.
  void printIDImpl(raw_ostream &OS, const CASID &ID) const final {
    PluginCASContext::printID(ID.getHash(), OS);
  }
};

struct CASWrapper {
std::string FirstPrefix;
std::string SecondPrefix;
Expand Down Expand Up @@ -308,15 +356,15 @@ llcas_cas_t llcas_cas_create(llcas_cas_options_t c_opts, char **error) {
auto &Opts = *unwrap(c_opts);
Expected<std::unique_ptr<UnifiedOnDiskCache>> DB = UnifiedOnDiskCache::open(
Opts.OnDiskPath, /*SizeLimit=*/std::nullopt,
BuiltinCASContext::getHashName(), sizeof(HashType));
PluginCASContext::getHashName(), sizeof(HashType));
if (!DB)
return reportError<llcas_cas_t>(DB.takeError(), error);

std::unique_ptr<UnifiedOnDiskCache> UpstreamDB;
if (!Opts.UpstreamPath.empty()) {
if (Error E = UnifiedOnDiskCache::open(
Opts.UpstreamPath, /*SizeLimit=*/std::nullopt,
BuiltinCASContext::getHashName(), sizeof(HashType))
PluginCASContext::getHashName(), sizeof(HashType))
.moveInto(UpstreamDB))
return reportError<llcas_cas_t>(std::move(E), error);
}
Expand Down Expand Up @@ -380,7 +428,7 @@ unsigned llcas_digest_parse(llcas_cas_t c_cas, const char *printed_digest,
assert(Consumed);
(void)Consumed;

Expected<HashType> Digest = BuiltinCASContext::parseID(PrintedDigest);
Expected<HashType> Digest = PluginCASContext::parseID(PrintedDigest);
if (!Digest)
return reportError(Digest.takeError(), error, 0);
std::uninitialized_copy(Digest->begin(), Digest->end(), bytes);
Expand All @@ -394,7 +442,7 @@ bool llcas_digest_print(llcas_cas_t c_cas, llcas_digest_t c_digest,
raw_svector_ostream OS(PrintDigest);
// Include these for testing purposes.
OS << Wrapper.FirstPrefix << Wrapper.SecondPrefix;
BuiltinCASContext::printID(ArrayRef(c_digest.data, c_digest.size), OS);
PluginCASContext::printID(ArrayRef(c_digest.data, c_digest.size), OS);
*printed_id = copyNewMallocString(PrintDigest);
return false;
}
Expand Down
36 changes: 13 additions & 23 deletions llvm/tools/llvm-cas/llvm-cas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ int main(int Argc, char **Argv) {
if (!UpstreamCAS)
ExitOnErr(createStringError(inconvertibleErrorCode(),
"missing '-upstream-cas'"));
return import(*CAS, *UpstreamCAS, Inputs);

return import(*UpstreamCAS, *CAS, Inputs);
}

if (Command == PutCacheKey || Command == GetCacheResult) {
Expand Down Expand Up @@ -641,32 +642,21 @@ int getCASIDForFile(ObjectStore &CAS, const CASID &ID,
return 0;
}

/// Recursively copy the object identified by \p ID from \p UpstreamCAS into
/// \p CAS and return its reference in \p CAS.
///
/// If \p CAS already has a reference for \p ID, that reference is returned
/// without touching the upstream store. Otherwise every object referenced by
/// the upstream object is imported first, then the object itself is stored
/// with the imported references and the upstream data.
///
/// Failures are passed to ExitOnErr (banner "llvm-cas: import: ") rather than
/// returned to the caller.
static ObjectRef importNode(ObjectStore &CAS, ObjectStore &UpstreamCAS,
const CASID &ID) {
ExitOnError ExitOnErr("llvm-cas: import: ");

// Fast path: nothing to copy when the destination already knows this ID.
std::optional<ObjectRef> PrimaryRef = CAS.getReference(ID);
if (PrimaryRef)
return *PrimaryRef; // object is present.

ObjectProxy UpstreamObj = ExitOnErr(UpstreamCAS.getProxy(ID));
// References in the destination CAS, collected in upstream reference order.
SmallVector<ObjectRef> Refs;
ExitOnErr(UpstreamObj.forEachReference([&](ObjectRef UpstreamRef) -> Error {
// Import the child first; the recursion depth follows the depth of the tree.
ObjectRef Ref =
importNode(CAS, UpstreamCAS, UpstreamCAS.getID(UpstreamRef));
Refs.push_back(Ref);
return Error::success();
}));
// All children are imported; store this node with the imported references.
return ExitOnErr(CAS.storeFromString(Refs, UpstreamObj.getData()));
}

static int import(ObjectStore &CAS, ObjectStore &UpstreamCAS,
static int import(ObjectStore &FromCAS, ObjectStore &ToCAS,
ArrayRef<std::string> Objects) {
ExitOnError ExitOnErr("llvm-cas: import: ");

for (StringRef Object : Objects) {
CASID ID = ExitOnErr(CAS.parseID(Object));
importNode(CAS, UpstreamCAS, ID);
CASID ID = ExitOnErr(FromCAS.parseID(Object));
auto Ref = FromCAS.getReference(ID);
if (!Ref) {
ExitOnErr(createStringError(inconvertibleErrorCode(),
"input not found: " + ID.toString()));
return 1;
}

auto Imported = ExitOnErr(ToCAS.importObject(FromCAS, *Ref));
llvm::outs() << ToCAS.getID(Imported).toString() << "\n";
}
return 0;
}
Expand Down