@@ -100,19 +100,13 @@ that). It should always be called by at least 32 threads to ensure the random st
100100initialized, even if you will be using the generator from fewer threads!
101101"""
102102@inline Base. @propagate_inbounds function Random. seed! (rng:: SharedTauswortheGenerator , seed)
103- # 0-indexed so that we can bitwise and instead of mod1
104- tid0 = threadIdx (). x - 1 + (threadIdx (). y - 1 ) * blockDim (). x +
105- (threadIdx (). z - 1 ) * blockDim (). x * blockDim (). y
106103 state = initial_state (seed)
107- @inbounds rng. state[tid0 & 31 + 1 ] = state
104+ @inbounds rng. state[laneid () ] = state
108105 return
109106end
110107
111108@inline Base. @propagate_inbounds function initial_state (seeds)
112- # 0-indexed so that we can bitwise and instead of mod1
113- tid0 = threadIdx (). x - 1 + (threadIdx (). y - 1 ) * blockDim (). x +
114- (threadIdx (). z - 1 ) * blockDim (). x * blockDim (). y
115- z = seeds[tid0& 31 + 1 ]
109+ z = seeds[laneid ()]
116110
117111 # add the block id to ensure unique values across blocks
118112 # XXX : is this OK? shouldn't we use a generator that allows skipping ahead?
@@ -166,28 +160,30 @@ Generate a byte of random data using the on-device Tausworthe generator.
166160 kernel may deadlock.
167161"""
168162function Random. rand (rng:: SharedTauswortheGenerator , :: Type{UInt32} )
169- # 0-indexed so that we can bitwise and instead of mod1
170- tid0 = threadIdx (). x - 1 + (threadIdx (). y - 1 ) * blockDim (). x +
171- (threadIdx (). z - 1 ) * blockDim (). x * blockDim (). y
172- i = tid0& 31 + 1
173- j = tid0& 3 + 1
163+ @inline pow2_mod1 (x, y) = (x- 1 )& (y- 1 ) + 1
164+
165+ i = laneid ()
166+ j = pow2_mod1 (i, 4 )
174167
175168 @inbounds begin
176- # get
169+ # get state
177170 z = rng. state[i]
178171 if z == 0
179172 z = initial_state (rng. seed)
180173 end
181174
182- sync_threads () # XXX : this implies that rand() cannot be called from a branch
175+ sync_threads ()
183176
184- # advance
177+ # advance & update state
185178 S1, S2, S3, M = TausShift1 ()[j], TausShift2 ()[j], TausShift3 ()[j], TausOffset ()[j]
186179 rng. state[i] = TausStep (z, S1, S2, S3, M)
187180
188181 sync_threads ()
189182
190- # update
191- rng. state[tid0& 31 + 1 ] ⊻ rng. state[(tid0+ 1 )& 31 + 1 ] ⊻ rng. state[(tid0+ 1 )& 31 + 1 ] ⊻ rng. state[(tid0+ 1 )& 31 + 1 ]
183+ # generate
184+ rng. state[i] ⊻
185+ rng. state[pow2_mod1 (i+ 1 , 32 )] ⊻
186+ rng. state[pow2_mod1 (i+ 2 , 32 )] ⊻
187+ rng. state[pow2_mod1 (i+ 3 , 32 )]
192188 end
193189end
0 commit comments