@@ -101,6 +101,12 @@ Tensor& log_normal_(Tensor& self, double mean, double std, c10::optional<Generat
101
101
return at::native::templates::log_normal_impl_<native::templates::cpu::LogNormalKernel, TestCPUGenerator>(self, mean, std, gen);
102
102
}
103
103
104
+ // ================================================== Geometric =======================================================
105
+
106
+ Tensor& geometric_ (Tensor& self, double p, c10::optional<Generator> gen) {
107
+ return at::native::templates::geometric_impl_<native::templates::cpu::GeometricKernel, TestCPUGenerator>(self, p, gen);
108
+ }
109
+
104
110
TORCH_LIBRARY_IMPL (aten, CustomRNGKeyId, m) {
105
111
// Random
106
112
m.impl_UNBOXED (" random_.from" , random_from_to);
@@ -119,6 +125,8 @@ TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
119
125
m.impl_UNBOXED (" cauchy_" , custom_rng_cauchy_);
120
126
// LogNormal
121
127
m.impl_UNBOXED (" log_normal_" , log_normal_);
128
+ // Geometric
129
+ m.impl_UNBOXED (" geometric_" , geometric_);
122
130
}
123
131
124
132
class RNGTest : public ::testing::Test {
@@ -307,4 +315,20 @@ TEST_F(RNGTest, LogNormal) {
307
315
ASSERT_TRUE (torch::allclose (actual, expected));
308
316
}
309
317
318
+ // ================================================== Geometric =======================================================
319
+
320
+ TEST_F (RNGTest, Geometric) {
321
+ const auto p = 0.42 ;
322
+ auto gen = at::make_generator<TestCPUGenerator>(42.0 );
323
+
324
+ auto actual = torch::empty ({3 , 3 });
325
+ actual.geometric_ (p, gen);
326
+
327
+ auto expected = torch::empty_like (actual);
328
+ auto iter = TensorIterator::nullary_op (expected);
329
+ native::templates::cpu::geometric_kernel (iter, p, check_generator<TestCPUGenerator>(gen));
330
+
331
+ ASSERT_TRUE (torch::allclose (actual, expected));
332
+ }
333
+
310
334
}
0 commit comments