Skip to content

Commit 409a4b2

Browse files
committed
Revert accidental structured sparse changes
*Remove `structured_prefix` for Ops that map 0->0 *Do not introduce new structured ops that don't map 0->0 besides the ones that already existed
1 parent 5c350ab commit 409a4b2

File tree

2 files changed

+41
-90
lines changed

2 files changed

+41
-90
lines changed

pytensor/sparse/math.py

Lines changed: 21 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def wrapper(*args):
4141

4242

4343
@structured_elemwise(ptm.abs)
44-
def structured_abs(x):
44+
def abs(x):
4545
"""
4646
Compute abs(x) for all non-zero elements of x.
4747
"""
@@ -61,34 +61,13 @@ def structured_exp(x):
6161
"""
6262

6363

64-
@structured_elemwise(ptm.exp2)
65-
def structured_exp2(x):
66-
"""
67-
Compute exp2(x) for all non-zero elements of x.
68-
"""
69-
70-
7164
@structured_elemwise(ptm.log)
7265
def structured_log(x):
7366
"""
7467
Compute log(x) for all non-zero elements of x.
7568
"""
7669

7770

78-
@structured_elemwise(ptm.log2)
79-
def structured_log2(x):
80-
"""
81-
Compute log2(x) for all non-zero elements of x.
82-
"""
83-
84-
85-
@structured_elemwise(ptm.log10)
86-
def structured_log10(x):
87-
"""
88-
Compute log10(x) for all non-zero elements of x.
89-
"""
90-
91-
9271
@structured_elemwise(ptm.pow)
9372
def structured_pow(x, y):
9473
"""
@@ -118,161 +97,133 @@ def structured_add(x, y):
11897

11998

12099
@structured_elemwise(ptm.sin)
121-
def structured_sin(x):
100+
def sin(x):
122101
"""
123102
Compute sin(x) for all non-zero elements of x.
124103
"""
125104

126105

127106
@structured_elemwise(ptm.sinh)
128-
def structured_sinh(x):
107+
def sinh(x):
129108
"""
130109
Compute sinh(x) for all non-zero elements of x.
131110
"""
132111

133112

134113
@structured_elemwise(ptm.arcsin)
135-
def structured_arcsin(x):
114+
def arcsin(x):
136115
"""
137116
Compute arcsin(x) for all non-zero elements of x.
138117
"""
139118

140119

141120
@structured_elemwise(ptm.arcsinh)
142-
def structured_arcsinh(x):
121+
def arcsinh(x):
143122
"""
144123
Compute arcsinh(x) for all non-zero elements of x.
145124
"""
146125

147126

148-
@structured_elemwise(ptm.cos)
149-
def structured_cos(x):
150-
"""
151-
Compute cos(x) for all non-zero elements of x.
152-
"""
153-
154-
155-
@structured_elemwise(ptm.cosh)
156-
def structured_cosh(x):
157-
"""
158-
Compute cosh(x) for all non-zero elements of x.
159-
"""
160-
161-
162-
@structured_elemwise(ptm.arccos)
163-
def structured_arccos(x):
164-
"""
165-
Compute arccos(x) for all non-zero elements of x.
166-
"""
167-
168-
169-
@structured_elemwise(ptm.arccosh)
170-
def structured_arccosh(x):
171-
"""
172-
Compute arccosh(x) for all non-zero elements of x.
173-
"""
174-
175-
176127
@structured_elemwise(ptm.tan)
177-
def structured_tan(x):
128+
def tan(x):
178129
"""
179130
Compute tan(x) for all non-zero elements of x.
180131
"""
181132

182133

183134
@structured_elemwise(ptm.tanh)
184-
def structured_tanh(x):
135+
def tanh(x):
185136
"""
186137
Compute tanh(x) for all non-zero elements of x.
187138
"""
188139

189140

190141
@structured_elemwise(ptm.arctan)
191-
def structured_arctan(x):
142+
def arctan(x):
192143
"""
193144
Compute arctan(x) for all non-zero elements of x.
194145
"""
195146

196147

197148
@structured_elemwise(ptm.arctanh)
198-
def structured_arctanh(x):
149+
def arctanh(x):
199150
"""
200151
Compute arctanh(x) for all non-zero elements of x.
201152
"""
202153

203154

204155
@structured_elemwise(ptm.round_half_to_even)
205-
def structured_rint(x):
156+
def rint(x):
206157
"""
207158
Compute round_half_to_even(x) for all non-zero elements of x.
208159
"""
209160

210161

211162
@structured_elemwise(ptm.sign)
212-
def structured_sign(x):
163+
def sign(x):
213164
"""
214165
Compute sign(x) for all non-zero elements of x.
215166
"""
216167

217168

218169
@structured_elemwise(ptm.ceil)
219-
def structured_ceil(x):
170+
def ceil(x):
220171
"""
221172
Compute ceil(x) for all non-zero elements of x.
222173
"""
223174

224175

225176
@structured_elemwise(ptm.floor)
226-
def structured_floor(x):
177+
def floor(x):
227178
"""
228179
Compute floor(x) for all non-zero elements of x.
229180
"""
230181

231182

232183
@structured_elemwise(ptm.log1p)
233-
def structured_log1p(x):
184+
def log1p(x):
234185
"""
235186
Compute log(1 + x) for all non-zero elements of x.
236187
"""
237188

238189

239190
@structured_elemwise(ptm.expm1)
240-
def structured_expm1(x):
191+
def expm1(x):
241192
"""
242193
Compute exp(x) - 1 for all non-zero elements of x.
243194
"""
244195

245196

246197
@structured_elemwise(ptm.deg2rad)
247-
def structured_deg2rad(x):
198+
def deg2rad(x):
248199
"""
249200
Convert degrees to radians for all non-zero elements of x.
250201
"""
251202

252203

253204
@structured_elemwise(ptm.rad2deg)
254-
def structured_rad2deg(x):
205+
def rad2deg(x):
255206
"""
256207
Convert radians to degrees for all non-zero elements of x.
257208
"""
258209

259210

260211
@structured_elemwise(ptm.trunc)
261-
def structured_trunc(x):
212+
def trunc(x):
262213
"""
263214
Truncate the decimal part of x for all non-zero elements of x.
264215
"""
265216

266217

267218
@structured_elemwise(ptm.sqr)
268-
def structured_sqr(x):
219+
def sqr(x):
269220
"""
270221
Compute sqr(x) for all non-zero elements of x.
271222
"""
272223

273224

274225
@structured_elemwise(ptm.sqrt)
275-
def structured_sqrt(x):
226+
def sqrt(x):
276227
"""
277228
Compute sqrt(x) for all non-zero elements of x.
278229
"""
@@ -292,7 +243,7 @@ def conjugate(x):
292243
return _conj(_x)
293244

294245

295-
structured_conjugate = conjugate
246+
structured_conjugate = conj = conjugate
296247

297248

298249
class SpSum(Op):

tests/sparse/test_math.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,32 +1382,32 @@ def test_grad(self):
13821382
name="StructuredAddTester",
13831383
)
13841384

1385-
SinTester = elemwise_checker(psm.structured_sin, np.sin)
1385+
SinTester = elemwise_checker(psm.sin, np.sin)
13861386

1387-
TanTester = elemwise_checker(psm.structured_tan, np.tan, gap=(-1, 1))
1387+
TanTester = elemwise_checker(psm.tan, np.tan, gap=(-1, 1))
13881388

13891389
ArcsinTester = elemwise_checker(
1390-
psm.structured_arcsin, np.arcsin, gap=(-1, 1), gap_grad=(-0.99, 0.99)
1390+
psm.arcsinh, np.arcsin, gap=(-1, 1), gap_grad=(-0.99, 0.99)
13911391
)
13921392

1393-
ArctanTester = elemwise_checker(psm.structured_arctan, np.arctan)
1393+
ArctanTester = elemwise_checker(psm.arctan, np.arctan)
13941394

1395-
SinhTester = elemwise_checker(psm.structured_sinh, np.sinh)
1395+
SinhTester = elemwise_checker(psm.sinh, np.sinh)
13961396

1397-
ArcsinhTester = elemwise_checker(psm.structured_arcsinh, np.arcsinh, gap=(-1, 1))
1397+
ArcsinhTester = elemwise_checker(psm.arcsinh, np.arcsinh, gap=(-1, 1))
13981398

1399-
TanhTester = elemwise_checker(psm.structured_tanh, np.tanh, gap=(-1, 1))
1399+
TanhTester = elemwise_checker(psm.tanh, np.tanh, gap=(-1, 1))
14001400

14011401
ArctanhTester = elemwise_checker(
1402-
psm.structured_arctanh, np.arctanh, gap=(-0.9, 1), gap_grad=(-0.9, 0.95)
1402+
psm.arctanh, np.arctanh, gap=(-0.9, 1), gap_grad=(-0.9, 0.95)
14031403
)
14041404

14051405
RintTester = elemwise_checker(
1406-
psm.structured_rint, np.rint, grad_test=False, test_dtypes=float_dtypes
1406+
psm.rint, np.rint, grad_test=False, test_dtypes=float_dtypes
14071407
)
14081408

14091409
SgnTester = elemwise_checker(
1410-
psm.structured_sign,
1410+
psm.sign,
14111411
np.sign,
14121412
grad_test=False,
14131413
test_dtypes=[
@@ -1416,46 +1416,46 @@ def test_grad(self):
14161416
)
14171417

14181418
CeilTester = elemwise_checker(
1419-
psm.structured_ceil,
1419+
psm.ceil,
14201420
np.ceil,
14211421
grad_test=False,
14221422
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
14231423
)
14241424

14251425
FloorTester = elemwise_checker(
1426-
psm.structured_floor,
1426+
psm.floor,
14271427
np.floor,
14281428
grad_test=False,
14291429
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
14301430
)
14311431

1432-
Log1pTester = elemwise_checker(psm.structured_log1p, np.log1p, gap=(0.5, 10))
1432+
Log1pTester = elemwise_checker(psm.log1p, np.log1p, gap=(0.5, 10))
14331433

1434-
Expm1Tester = elemwise_checker(psm.structured_expm1, np.expm1)
1434+
Expm1Tester = elemwise_checker(psm.expm1, np.expm1)
14351435

14361436
Deg2radTester = elemwise_checker(
1437-
psm.structured_deg2rad,
1437+
psm.deg2rad,
14381438
np.deg2rad,
14391439
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
14401440
)
14411441

14421442
Rad2degTester = elemwise_checker(
1443-
psm.structured_rad2deg,
1443+
psm.rad2deg,
14441444
np.rad2deg,
14451445
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
14461446
)
14471447

14481448

14491449
TruncTester = elemwise_checker(
1450-
psm.structured_trunc,
1450+
psm.trunc,
14511451
np.trunc,
14521452
test_dtypes=[m for m in all_dtypes if m not in complex_dtypes],
14531453
grad_test=False,
14541454
)
14551455

14561456

1457-
SqrTester = elemwise_checker(psm.structured_sqr, lambda x: x * x)
1457+
SqrTester = elemwise_checker(psm.sqr, lambda x: x * x)
14581458

1459-
SqrtTester = elemwise_checker(psm.structured_sqrt, np.sqrt, gap=(0, 10))
1459+
SqrtTester = elemwise_checker(psm.sqrt, np.sqrt, gap=(0, 10))
14601460

1461-
ConjTester = elemwise_checker(psm.structured_conjugate, np.conj, grad_test=False)
1461+
ConjTester = elemwise_checker(psm.conjugate, np.conj, grad_test=False)

0 commit comments

Comments
 (0)