Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ function run_torch_xla_cpp_tests() {
TORCH_DIR=$(python -c "import pkgutil; import os; print(os.path.dirname(pkgutil.get_loader('torch').get_filename()))")
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${TORCH_DIR}/lib
export PJRT_DEVICE=CPU
export CPU_NUM_DEVICES=2
export XLA_EXPERIMENTAL="nonzero:masked_select:nms"

test_names=("test_aten_xla_tensor_1"
Expand Down
12 changes: 12 additions & 0 deletions test/cpp/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ if [[ "$BAZEL_VERB" == "coverage" ]]; then
EXTRA_FLAGS="$EXTRA_FLAGS --remote_download_outputs=all" # for lcov symlink
fi

# Forward PJRT_DEVICE and CPU_NUM_DEVICES to the bazel test environment.
# Unset or empty values fall back to CPU / 2 so tests run reproducibly.
# After defaulting, both variables are guaranteed non-empty, so the
# --test_env flags are appended unconditionally.
export PJRT_DEVICE="${PJRT_DEVICE:-CPU}"
export CPU_NUM_DEVICES="${CPU_NUM_DEVICES:-2}"
EXTRA_FLAGS="$EXTRA_FLAGS --test_env=PJRT_DEVICE=${PJRT_DEVICE}"
EXTRA_FLAGS="$EXTRA_FLAGS --test_env=CPU_NUM_DEVICES=${CPU_NUM_DEVICES}"

test_names=("all")
if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then
test_names=("test_aten_xla_tensor_1"
Expand Down
122 changes: 121 additions & 1 deletion test/cpp/test_xla_generator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <torch/torch.h>

#include <cstdlib>

#include "test/cpp/torch_xla_test.h"
#include "torch_xla/csrc/xla_generator.h"

Expand Down Expand Up @@ -102,5 +105,122 @@ TEST_F(XLAGeneratorTest, Clone) {
ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed());
}

TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) {
  // An explicit request for device 0 yields an XLA generator bound to index 0.
  auto dev0_or = at::detail::GetDefaultXLAGenerator(0);
  ASSERT_TRUE(dev0_or.ok()) << "Failed to get default generator: "
                            << dev0_or.status();

  const at::Generator& dev0_gen = dev0_or.value();
  ASSERT_EQ(dev0_gen.device().type(), at::DeviceType::XLA);
  ASSERT_EQ(dev0_gen.device().index(), 0);

  // Index -1 means "current device", which is expected to be device 0 here,
  // so the very same default generator must come back.
  auto current_or = at::detail::GetDefaultXLAGenerator(-1);
  ASSERT_TRUE(current_or.ok())
      << "Failed to get default generator with -1: " << current_or.status();

  const at::Generator& current_gen = current_or.value();
  ASSERT_EQ(current_gen.device().type(), at::DeviceType::XLA);
  ASSERT_EQ(current_gen.device().index(), 0);
  ASSERT_EQ(dev0_gen, current_gen);

  // Repeated lookups for the same device return the same generator instance.
  auto dev0_again_or = at::detail::GetDefaultXLAGenerator(0);
  ASSERT_TRUE(dev0_again_or.ok());
  const at::Generator& dev0_gen_again = dev0_again_or.value();
  ASSERT_EQ(dev0_gen, dev0_gen_again);

  // A non-default device (index 1) has its own, distinct default generator.
  // This requires at least two XLA devices to be visible to the test.
  auto dev1_or = at::detail::GetDefaultXLAGenerator(1);
  ASSERT_TRUE(dev1_or.ok())
      << "Failed to get default generator for device 1: "
      << dev1_or.status();

  const at::Generator& dev1_gen = dev1_or.value();
  ASSERT_EQ(dev1_gen.device().type(), at::DeviceType::XLA);
  ASSERT_EQ(dev1_gen.device().index(), 1);
  ASSERT_NE(dev1_gen, dev0_gen);
}

TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) {
  // A negative index other than the -1 sentinel must be rejected.
  auto below_range_or = at::detail::GetDefaultXLAGenerator(-2);
  ASSERT_FALSE(below_range_or.ok());
  ASSERT_TRUE(absl::IsInvalidArgument(below_range_or.status()));
  ASSERT_THAT(below_range_or.status().message(),
              testing::HasSubstr("Invalid XLA device index"));

  // An index past the device count must be rejected as well (assuming there
  // aren't 100 XLA devices).
  auto above_range_or = at::detail::GetDefaultXLAGenerator(100);
  ASSERT_FALSE(above_range_or.ok());
  ASSERT_TRUE(absl::IsInvalidArgument(above_range_or.status()));
  ASSERT_THAT(above_range_or.status().message(),
              testing::HasSubstr("Invalid XLA device index"));
}

TEST_F(XLAGeneratorTest, CreateXLAGenerator) {
  // Creating a generator for device 1 binds it to XLA device index 1.
  auto dev1_or = at::detail::CreateXLAGenerator(1);
  ASSERT_TRUE(dev1_or.ok()) << "Failed to create generator: " << dev1_or.status();

  at::Generator dev1_gen = dev1_or.value();
  ASSERT_EQ(dev1_gen.device().type(), at::DeviceType::XLA);
  ASSERT_EQ(dev1_gen.device().index(), 1);

  // A freshly created generator starts from the default RNG seed.
  ASSERT_EQ(dev1_gen.current_seed(), c10::default_rng_seed_val);

  // Index -1 selects the current device.
  auto current_or = at::detail::CreateXLAGenerator(-1);
  ASSERT_TRUE(current_or.ok())
      << "Failed to create generator with -1: " << current_or.status();

  at::Generator current_gen = current_or.value();
  ASSERT_EQ(current_gen.device().type(), at::DeviceType::XLA);
  // Which device is "current" depends on the runtime, so only require that the
  // index was resolved to a concrete, non-negative value.
  ASSERT_GE(current_gen.device().index(), 0);
}

TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) {
  // Two back-to-back creations for the same device.
  auto first_or = at::detail::CreateXLAGenerator(0);
  auto second_or = at::detail::CreateXLAGenerator(0);

  ASSERT_TRUE(first_or.ok());
  ASSERT_TRUE(second_or.ok());

  at::Generator first_gen = first_or.value();
  at::Generator second_gen = second_or.value();

  // Every call must hand out a distinct generator (generator equality is
  // compared here, not the addresses of the local variables).
  ASSERT_NE(first_gen, second_gen);

  // Both start out equivalent: same device, same initial seed.
  ASSERT_EQ(first_gen.device(), second_gen.device());
  ASSERT_EQ(first_gen.current_seed(), second_gen.current_seed());

  // Reseeding one generator must leave the other untouched.
  first_gen.set_current_seed(12345);
  ASSERT_NE(first_gen.current_seed(), second_gen.current_seed());
}

TEST_F(XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) {
  // A negative index other than the -1 sentinel must be rejected.
  auto below_range_or = at::detail::CreateXLAGenerator(-2);
  ASSERT_FALSE(below_range_or.ok());
  ASSERT_TRUE(absl::IsInvalidArgument(below_range_or.status()));
  ASSERT_THAT(below_range_or.status().message(),
              testing::HasSubstr("Invalid XLA device index"));

  // An index past the device count must be rejected as well (assuming there
  // aren't 100 XLA devices).
  auto above_range_or = at::detail::CreateXLAGenerator(100);
  ASSERT_FALSE(above_range_or.ok());
  ASSERT_TRUE(absl::IsInvalidArgument(above_range_or.status()));
  ASSERT_THAT(above_range_or.status().message(),
              testing::HasSubstr("Invalid XLA device index"));
}

} // namespace cpp_test
} // namespace torch_xla
} // namespace torch_xla
103 changes: 103 additions & 0 deletions torch_xla/csrc/xla_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,113 @@
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/CallOnce.h>
#include <c10/util/intrusive_ptr.h>

#include <cstring>
#include <deque>
#include <vector>

#include "absl/status/status.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/status.h"

namespace at {

namespace detail {

namespace {

// Total number of XLA devices in the system. Written by InitGlobalVars();
// NormalizeXLADeviceIndex() uses it as the exclusive upper bound for valid
// device indices.
static int64_t num_xla_devices;

// One once-flag per XLA device (sized by InitGlobalVars()). The flag at index
// i ensures default_gens_xla[i] is created exactly once, on first request.
static std::deque<c10::once_flag> xla_gens_init_flag;

// Default, global XLA generators, one slot per XLA device. Slots are sized by
// InitGlobalVars() and filled in lazily by GetDefaultXLAGenerator().
static std::vector<at::Generator> default_gens_xla;

/*
 * Populates the global variables related to XLA generators: the device
 * count, the per-device once-flags and the default-generator slots.
 *
 * This function may be called any number of times. The population step is
 * the initializer of a function-local static, so it executes only on the
 * first call; its resulting status (success or failure) is cached and
 * returned by every call. The cached status is heap-allocated and
 * intentionally never freed.
 *
 * Returns the error from GetComputationClient() if the client could not be
 * obtained, OK otherwise.
 */
static absl::Status InitGlobalVars() {
  auto populate_globals = []() {
    XLA_ASSIGN_OR_RETURN(auto client,
                         torch_xla::runtime::GetComputationClient());
    num_xla_devices = static_cast<int64_t>(client->GetNumDevices());
    // Size both per-device containers to match the device count.
    default_gens_xla.resize(num_xla_devices);
    xla_gens_init_flag.resize(num_xla_devices);
    return absl::OkStatus();
  };
  static const absl::Status* cached_status =
      new absl::Status(populate_globals());
  return *cached_status;
}

// Validates and normalizes an XLA device index.
//
// If requested_index == -1, the current ATen XLA device index is used
// instead. The resolved index must lie in [0, num_xla_devices), so
// InitGlobalVars() must have populated num_xla_devices before this is called.
//
// Returns the resolved index, or InvalidArgument if it is out of range. The
// error message reports both the index the caller passed and the index it
// resolved to, plus the valid range, so that a failure triggered through the
// -1 sentinel is not misreported as if the caller had passed the resolved
// value.
static absl::StatusOr<c10::DeviceIndex> NormalizeXLADeviceIndex(
    c10::DeviceIndex requested_index) {
  const c10::DeviceIndex resolved_index =
      requested_index == -1
          ? torch_xla::bridge::GetCurrentAtenDevice().index()
          : requested_index;
  if (resolved_index < 0 || resolved_index >= num_xla_devices) {
    return absl::InvalidArgumentError(
        "Invalid device index for XLA generator. Provided index: " +
        std::to_string(requested_index) + " (resolved to " +
        std::to_string(resolved_index) + "; valid range: [0, " +
        std::to_string(num_xla_devices) + "))");
  }
  return resolved_index;
}

} // anonymous namespace

/**
 * Returns the default generator of a particular XLA device.
 *
 * PyTorch keeps one default generator per device to carry the global running
 * state of pseudo random number generation whenever the user does not pass
 * an explicit generator. The XLA default generators are created lazily: the
 * first request for a device constructs its generator and calls seed() on
 * it; later requests return that same instance.
 *
 * device_index == -1 selects the current XLA device. An error status is
 * returned if the globals cannot be initialized or the index is invalid.
 */
absl::StatusOr<const at::Generator&> GetDefaultXLAGenerator(
    c10::DeviceIndex device_index) {
  XLA_RETURN_IF_ERROR(InitGlobalVars(), "Failed to initialize XLA generators");
  // Resolve -1 to the current device and reject out-of-range indices.
  XLA_ASSIGN_OR_RETURN(c10::DeviceIndex resolved_index,
                       NormalizeXLADeviceIndex(device_index),
                       "Invalid XLA device index");
  // The slot vector is sized once by InitGlobalVars(), so this reference
  // stays valid for the lifetime of the process.
  at::Generator& slot = default_gens_xla[resolved_index];
  c10::call_once(xla_gens_init_flag[resolved_index], [&] {
    slot = at::make_generator<XLAGeneratorImpl>(resolved_index);
    slot.seed();
  });
  return slot;
}

/**
 * Creates a brand-new generator backed by XLAGeneratorImpl for the given XLA
 * device and returns it by value as an at::Generator.
 *
 * Unlike GetDefaultXLAGenerator, every call produces an independent
 * generator; its seed is set to c10::default_rng_seed_val.
 *
 * device_index == -1 selects the current XLA device. An error status is
 * returned if the globals cannot be initialized or the index is invalid.
 */
absl::StatusOr<at::Generator> CreateXLAGenerator(
    c10::DeviceIndex device_index) {
  XLA_RETURN_IF_ERROR(InitGlobalVars(), "Failed to initialize XLA generators");
  // Resolve -1 to the current device and reject out-of-range indices.
  XLA_ASSIGN_OR_RETURN(c10::DeviceIndex resolved_index,
                       NormalizeXLADeviceIndex(device_index),
                       "Invalid XLA device index");
  auto generator = at::make_generator<XLAGeneratorImpl>(resolved_index);
  // Seed through the concrete XLA implementation behind the generic handle.
  at::check_generator<XLAGeneratorImpl>(generator)->set_current_seed(
      c10::default_rng_seed_val);
  return generator;
}

} // namespace detail
} // namespace at

namespace at {

Expand Down
18 changes: 17 additions & 1 deletion torch_xla/csrc/xla_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@

#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/intrusive_ptr.h>

#include <cstdint>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

namespace at {

// Holds the actual state variables for the XLA generator.
Expand Down Expand Up @@ -53,4 +60,13 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
c10::intrusive_ptr<XLAGeneratorState> state_;
};

} // namespace at
namespace detail {

// Returns the default generator for the XLA device `device_index`; the
// generator is created on first request and the same instance is returned
// afterwards. Passing -1 (the default) selects the current XLA device.
// Returns an error status if the index is out of range or the XLA generator
// globals could not be initialized.
absl::StatusOr<const at::Generator&> GetDefaultXLAGenerator(
    c10::DeviceIndex device_index = -1);
// Creates and returns a new, independent generator for the XLA device
// `device_index`, seeded with c10::default_rng_seed_val. Passing -1 (the
// default) selects the current XLA device. Returns an error status if the
// index is out of range or the XLA generator globals could not be
// initialized.
absl::StatusOr<at::Generator> CreateXLAGenerator(
    c10::DeviceIndex device_index = -1);

}  // namespace detail

} // namespace at