
Commit 06168bf

Move geometric_() to DistributionTemplates (pytorch#37418)

Authored by pbelevich and committed by facebook-github-bot.

Summary: Pull Request resolved: pytorch#37418. Fixes pytorch#37369.

Test Plan: Imported from OSS

Differential Revision: D21290757

Pulled By: pbelevich

fbshipit-source-id: 42133f35edcbe716a07987bef2e68a4cdc27236a

1 parent: ce6077d

8 files changed: +111 -42 lines

aten/src/ATen/native/DistributionTemplates.h (+10)

```diff
@@ -298,6 +298,16 @@ at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, c10::opt
   return self;
 }
 
+// =================================================== Geometric ======================================================
+
+template<template<typename> class geometric_kernel, typename RNG>
+Tensor& geometric_impl_(Tensor& self, double p, c10::optional<Generator> gen) {
+  TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
+  auto iter = TensorIterator::nullary_op(self);
+  geometric_kernel<RNG>()(iter, p, gen);
+  return self;
+}
+
 #undef CHECK_OUT_OF_BOUNDS_AND_SHOW_WARNING
 
 }}}
```
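The new `geometric_impl_` takes a template-template parameter so each backend can plug in its own kernel functor while sharing the argument check. Below is a self-contained sketch of that pattern, stripped of ATen types; all names in it are invented for illustration, not taken from the patch:

```cpp
#include <iostream>
#include <stdexcept>

// Backend-specific kernel functor, parameterized on an RNG type.
template<typename RNG>
struct PrintKernel {
  void operator()(double p, RNG& rng) {
    std::cout << "sampling with p=" << p << " via " << rng.name() << "\n";
  }
};

struct FakeRNG {
  const char* name() const { return "FakeRNG"; }
};

// Shared front end mirroring geometric_impl_: validate the argument once,
// then instantiate and invoke whichever kernel functor the caller names.
template<template<typename> class kernel, typename RNG>
void sample_impl(double p, RNG& rng) {
  if (!(0 < p && p < 1)) throw std::invalid_argument("p must be in (0, 1)");
  kernel<RNG>()(p, rng);
}

int main() {
  FakeRNG rng;
  sample_impl<PrintKernel>(0.42, rng);  // RNG deduced as FakeRNG
}
```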

aten/src/ATen/native/Distributions.cpp (+10 -4)

```diff
@@ -221,11 +221,17 @@ Tensor& exponential_(Tensor& self, double lambda, c10::optional<Generator> gen)
   return self;
 }
 
+// =================================================== Geometric ======================================================
+
+template<typename RNG>
+struct GeometricStub {
+  void operator()(TensorIterator& iter, double p, c10::optional<Generator> gen) {
+    geometric_stub(iter.device_type(), iter, p, gen);
+  }
+};
+
 Tensor& geometric_(Tensor& self, double p, c10::optional<Generator> gen) {
-  TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
-  auto iter = TensorIterator::nullary_op(self);
-  geometric_stub(iter.device_type(), iter, p, gen);
-  return self;
+  return at::native::templates::geometric_impl_<GeometricStub, Generator>(self, p, gen);
 }
 
 // ==================================================== Uniform =======================================================
```
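For context, `geometric_()` above is the entry point behind the in-place tensor method. A minimal libtorch usage sketch (assuming a standard ATen build; output values are random):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Fill a float tensor in place with geometric samples. This routes through
  // geometric_() above, which validates p and reaches the registered CPU or
  // CUDA kernel via geometric_stub.
  at::Tensor t = at::empty({2, 3}, at::kFloat);
  t.geometric_(/*p=*/0.5);
  std::cout << t << "\n";  // positive integers k with P(X = k) = (1-p)^(k-1) * p
  return 0;
}
```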

aten/src/ATen/native/cpu/DistributionTemplates.h (+20)

```diff
@@ -256,4 +256,24 @@ struct LogNormalKernel {
   }
 };
 
+// =================================================== Geometric ======================================================
+
+template<typename RNG>
+void geometric_kernel(TensorIterator& iter, double p, RNG generator) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "geometric_cpu", [&]() {
+    std::lock_guard<std::mutex> lock(generator->mutex_);
+    cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
+      at::geometric_distribution<double> geometric(p);
+      return (scalar_t)geometric(generator);
+    });
+  });
+}
+
+template<typename RNG>
+struct GeometricKernel {
+  void operator()(TensorIterator& iter, double p, c10::optional<Generator> gen) {
+    geometric_kernel(iter, p, check_generator<RNG>(gen));
+  }
+};
+
 }}}}}
```
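`at::geometric_distribution` comes from ATen's shared RNG helpers. As a rough standalone illustration of the inverse-CDF transform behind it (the same ceil/log form appears explicitly in the CUDA kernel further down), assuming a uniform source on (0, 1):

```cpp
#include <cmath>
#include <cstdio>
#include <random>

int main() {
  const double p = 0.42;  // success probability, must lie in (0, 1)
  std::mt19937 rng(42);
  // Keep u strictly positive so std::log(u) stays finite.
  std::uniform_real_distribution<double> uniform(std::nextafter(0.0, 1.0), 1.0);
  for (int i = 0; i < 5; ++i) {
    const double u = uniform(rng);
    // Inverse-CDF transform: ceil(log(u) / log(1-p)) is Geometric(p) on {1, 2, ...}.
    std::printf("%.0f\n", std::ceil(std::log(u) / std::log(1.0 - p)));
  }
  return 0;
}
```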

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp (+2 -8)

```diff
@@ -329,14 +329,8 @@ static void exponential_kernel(TensorIterator& iter, double lambda, c10::optiona
 }
 
 static void geometric_kernel(TensorIterator& iter, double p, c10::optional<Generator> gen) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "geometric_cpu", [&]() {
-    CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
-    std::lock_guard<std::mutex> lock(generator->mutex_);
-    cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
-      at::geometric_distribution<double> geometric(p);
-      return (scalar_t)geometric(generator);
-    });
-  });
+  CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
+  templates::cpu::geometric_kernel(iter, p, generator);
 }
 
 static void log_normal_kernel(TensorIterator& iter, double mean, double std, c10::optional<Generator> gen) {
```
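This function is what the CPU build registers for `geometric_stub`, while the CUDA file below registers its own. A toy sketch of the same per-device registry idea (every name here is invented for illustration; ATen's real `DispatchStub` and `REGISTER_DISPATCH` are more general):

```cpp
#include <cstdio>

// Toy stand-in for ATen's DispatchStub: one function pointer per device kind.
enum class DeviceType { CPU, CUDA };

struct GeometricStubToy {
  using fn_t = void (*)(double p);
  fn_t cpu = nullptr;
  fn_t cuda = nullptr;
  void operator()(DeviceType d, double p) {
    (d == DeviceType::CPU ? cpu : cuda)(p);
  }
};

GeometricStubToy geometric_stub_toy;

static void geometric_kernel_cpu(double p) { std::printf("cpu kernel, p=%g\n", p); }

// Analogue of REGISTER_DISPATCH(geometric_stub, &geometric_kernel);
struct Registrar {
  Registrar() { geometric_stub_toy.cpu = &geometric_kernel_cpu; }
} registrar;

int main() { geometric_stub_toy(DeviceType::CPU, 0.5); }
```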

aten/src/ATen/native/cuda/DistributionGeometricKernel.cu (+4 -25)

```diff
@@ -29,32 +29,11 @@
 
 namespace at { namespace native {
 
-void geometric_kernel_cuda(TensorIterator& iter, double p_, c10::optional<Generator> gen_) {
-  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
-    if (std::is_same<scalar_t, double>::value) {
-      // define lambda for geometric transformation
-      auto geometric_func = [p_] __device__ (double rand) {
-        return static_cast<scalar_t>(::ceil(::log(rand) / ::log(static_cast<double>(1.0)-p_)));
-      };
-      distribution_nullary_kernel<scalar_t, double, curand4_engine_calls/2>(iter,
-        gen,
-        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
-        geometric_func);
-    } else {
-      auto p = static_cast<float>(p_);
-      auto geometric_func = [p] __device__ (float rand) {
-        // use __logf fast approximation for peak bandwidth
-        return static_cast<scalar_t>(::ceil(__logf(rand) / __logf(static_cast<float>(1.0)-p)));
-      };
-      distribution_nullary_kernel<scalar_t, float, curand4_engine_calls>(iter,
-        gen,
-        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
-        geometric_func);
-    }
-  });
+void geometric_kernel(TensorIterator& iter, double p_, c10::optional<Generator> gen) {
+  auto generator = get_generator_or_default<CUDAGeneratorImpl>(gen, cuda::detail::getDefaultCUDAGenerator());
+  at::native::templates::cuda::geometric_kernel(iter, p_, generator);
 }
 
-REGISTER_DISPATCH(geometric_stub, &geometric_kernel_cuda);
+REGISTER_DISPATCH(geometric_stub, &geometric_kernel);
 
 }} // namespace at::native
```

aten/src/ATen/native/cuda/DistributionTemplates.h (+35)

```diff
@@ -518,4 +518,39 @@ struct LogNormalKernel {
   }
 };
 
+// =================================================== Geometric ======================================================
+
+template<typename RNG>
+void geometric_kernel(TensorIterator& iter, double p_, RNG gen) {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
+    if (std::is_same<scalar_t, double>::value) {
+      // define lambda for geometric transformation
+      auto geometric_func = [p_] __device__ (double rand) {
+        return static_cast<scalar_t>(::ceil(::log(rand) / ::log(static_cast<double>(1.0)-p_)));
+      };
+      distribution_nullary_kernel<scalar_t, double, curand4_engine_calls/2>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
+        geometric_func);
+    } else {
+      auto p = static_cast<float>(p_);
+      auto geometric_func = [p] __device__ (float rand) {
+        // use __logf fast approximation for peak bandwidth
+        return static_cast<scalar_t>(::ceil(__logf(rand) / __logf(static_cast<float>(1.0)-p)));
+      };
+      distribution_nullary_kernel<scalar_t, float, curand4_engine_calls>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
+        geometric_func);
+    }
+  });
+}
+
+template<typename RNG>
+struct GeometricKernel {
+  void operator()(TensorIterator& iter, double p, c10::optional<Generator> gen) {
+    geometric_kernel(iter, p, check_generator<RNG>(gen));
+  }
+};
+
 }}}}
```
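A short derivation of why the ceil/log transform above produces geometric samples (standard inverse-CDF sampling; the derivation is editorial, not part of the patch):

```latex
% Since ln(1-p) < 0, dividing by it flips the inequality direction.
\begin{align*}
X &= \Big\lceil \frac{\ln U}{\ln(1-p)} \Big\rceil, \qquad U \sim \mathrm{Uniform}(0,1],\ 0 < p < 1 \\
\Pr[X \ge k] &= \Pr\!\Big[\frac{\ln U}{\ln(1-p)} > k-1\Big] = \Pr\big[U < (1-p)^{k-1}\big] = (1-p)^{k-1} \\
\Pr[X = k] &= (1-p)^{k-1} - (1-p)^{k} = (1-p)^{k-1}\,p, \qquad k = 1, 2, \dots
\end{align*}
```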

aten/src/ATen/test/cpu_rng_test.cpp (+24)

```diff
@@ -101,6 +101,12 @@ Tensor& log_normal_(Tensor& self, double mean, double std, c10::optional<Generat
   return at::native::templates::log_normal_impl_<native::templates::cpu::LogNormalKernel, TestCPUGenerator>(self, mean, std, gen);
 }
 
+// ================================================== Geometric =======================================================
+
+Tensor& geometric_(Tensor& self, double p, c10::optional<Generator> gen) {
+  return at::native::templates::geometric_impl_<native::templates::cpu::GeometricKernel, TestCPUGenerator>(self, p, gen);
+}
+
 TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
   // Random
   m.impl_UNBOXED("random_.from", random_from_to);
@@ -119,6 +125,8 @@ TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
   m.impl_UNBOXED("cauchy_", custom_rng_cauchy_);
   // LogNormal
   m.impl_UNBOXED("log_normal_", log_normal_);
+  // Geometric
+  m.impl_UNBOXED("geometric_", geometric_);
 }
 
 class RNGTest : public ::testing::Test {
@@ -307,4 +315,20 @@ TEST_F(RNGTest, LogNormal) {
   ASSERT_TRUE(torch::allclose(actual, expected));
 }
 
+// ================================================== Geometric =======================================================
+
+TEST_F(RNGTest, Geometric) {
+  const auto p = 0.42;
+  auto gen = at::make_generator<TestCPUGenerator>(42.0);
+
+  auto actual = torch::empty({3, 3});
+  actual.geometric_(p, gen);
+
+  auto expected = torch::empty_like(actual);
+  auto iter = TensorIterator::nullary_op(expected);
+  native::templates::cpu::geometric_kernel(iter, p, check_generator<TestCPUGenerator>(gen));
+
+  ASSERT_TRUE(torch::allclose(actual, expected));
+}
+
 }
```

test/test_torch.py (+6 -5)

```diff
@@ -10224,6 +10224,12 @@ def test_log_normal(self, device, dtype):
         self.assertEqual(a.dtype, dtype)
         self.assertEqual(a.size(), torch.Size([1]))
 
+    @dtypes(torch.float, torch.double)
+    def test_geometric(self, device, dtype):
+        a = torch.tensor([10], dtype=dtype, device=device).geometric_(0.5)
+        self.assertEqual(a.dtype, dtype)
+        self.assertEqual(a.size(), torch.Size([1]))
+
     def test_empty_strided(self, device):
         for shape in [(2, 3, 4), (0, 2, 0)]:
             # some of these cases are pretty strange, just verifying that if as_strided
@@ -10361,11 +10367,6 @@ def test_logical_all(self, device):
         y[-1][-1][-1] = 0
         self.assertEqual(y, x.all(2, keepdim=True))
 
-    def test_geometric(self, device):
-        a = torch.tensor([10], dtype=torch.float, device=device).geometric_(0.5)
-        self.assertEqual(a.dtype, torch.float)
-        self.assertEqual(a.size(), torch.Size([1]))
-
     @dtypes(torch.float32)
     def test_exponential(self, device, dtype):
         a = torch.tensor([10], dtype=torch.float, device=device).exponential_(0.5)
```
