15
15
from pyspark .sql import DataFrame
16
16
from pyspark .sql import functions as F
17
17
from pyspark .sql .functions import pandas_udf , PandasUDFType
18
- from pyspark .sql .types import StringType , LongType
18
+ from pyspark .sql .types import StringType , LongType , MapType , FloatType
19
19
20
20
import numpy as np
21
21
import pandas as pd
44
44
from fink_science .snn .processor import snn_ia_elasticc , snn_broad_elasticc
45
45
from fink_science .cats .processor import predict_nn
46
46
from fink_science .agn .processor import agn_elasticc
47
- from fink_science .t2 .processor import t2
47
+ from fink_science .slsn .processor import slsn_elasticc
48
+ # from fink_science.t2.processor import t2
48
49
49
50
from fink_broker .tester import spark_unit_tests
50
51
@@ -144,6 +145,22 @@ def ang2pix_array(ra: pd.Series, dec: pd.Series, nside: pd.Series) -> pd.Series:
144
145
145
146
return pd .Series (to_return )
146
147
148
@pandas_udf(MapType(StringType(), FloatType()), PandasUDFType.SCALAR)
def fake_t2(incol):
    """Return a map of all T2 class probabilities set to zero.

    Only for test purposes.
    """
    class_names = (
        'M-dwarf', 'KN', 'AGN', 'SLSN-I',
        'RRL', 'Mira', 'SNIax', 'TDE',
        'SNIa', 'SNIbc', 'SNIa-91bg',
        'mu-Lens-Single', 'EB', 'SNII',
    )
    # Every class gets probability 0.0; one entry per input row.
    zero_probs = dict.fromkeys(class_names, 0.0)
    return pd.Series([zero_probs for _ in range(len(incol))])
163
+
147
164
def apply_science_modules (df : DataFrame , logger : Logger ) -> DataFrame :
148
165
"""Load and apply Fink science modules to enrich alert content
149
166
@@ -317,9 +334,10 @@ def apply_science_modules(df: DataFrame, logger: Logger) -> DataFrame:
317
334
df = df .withColumn ('rf_kn_vs_nonkn' , knscore (* knscore_args ))
318
335
319
336
logger .info ("New processor: T2" )
320
- t2_args = ['candid' , 'cjd' , 'cfid' , 'cmagpsf' , 'csigmapsf' ]
321
- t2_args += [F .col ('roid' ), F .col ('cdsxmatch' ), F .col ('candidate.jdstarthist' )]
322
- df = df .withColumn ('t2' , t2 (* t2_args ))
337
+ # t2_args = ['candid', 'cjd', 'cfid', 'cmagpsf', 'csigmapsf']
338
+ # t2_args += [F.col('roid'), F.col('cdsxmatch'), F.col('candidate.jdstarthist')]
339
+ # df = df.withColumn('t2', t2(*t2_args))
340
+ df = df .withColumn ('t2' , fake_t2 ('objectId' ))
323
341
324
342
# Apply level one processor: snad (light curve features)
325
343
logger .info ("New processor: ad_features" )
@@ -406,9 +424,18 @@ def apply_science_modules_elasticc(df: DataFrame, logger: Logger) -> DataFrame:
406
424
df = df .withColumn ('redshift_err' , F .col ('diaObject.z_final_err' ))
407
425
408
426
logger .info ("New processor: EarlySN" )
427
+
409
428
args = ['cmidPointTai' , 'cfilterName' , 'cpsFlux' , 'cpsFluxErr' ]
410
- # fake args
411
- args += [F .col ('cdsxmatch' ), F .lit (20 ), F .lit (40 )]
429
+
430
+ # fake cdsxmatch and nobs
431
+ args += [F .col ('cdsxmatch' ), F .lit (20 )]
432
+ args += [F .col ('diaObject.ra' ), F .col ('diaObject.decl' )]
433
+ args += [F .col ('diaObject.hostgal_ra' ), F .col ('diaObject.hostgal_dec' )]
434
+ args += [F .col ('diaObject.hostgal_zphot' )]
435
+ args += [F .col ('diaObject.hostgal_zphot_err' ), F .col ('diaObject.mwebv' )]
436
+
437
+ # maxduration
438
+ args += [F .lit (40 )]
412
439
df = df .withColumn ('rf_snia_vs_nonia' , rfscore_sigmoid_elasticc (* args ))
413
440
414
441
# Apply level one processor: superNNova
@@ -429,7 +456,6 @@ def apply_science_modules_elasticc(df: DataFrame, logger: Logger) -> DataFrame:
429
456
df = df .withColumn ('preds_snn' , snn_broad_elasticc (* args ))
430
457
431
458
mapping_snn = {
432
- - 1 : 0 ,
433
459
0 : 11 ,
434
460
1 : 13 ,
435
461
2 : 12 ,
@@ -449,32 +475,17 @@ def apply_science_modules_elasticc(df: DataFrame, logger: Logger) -> DataFrame:
449
475
df = df .withColumn ('cbpf_preds' , predict_nn (* args ))
450
476
451
477
mapping_cats_general = {
452
- - 1 : 0 ,
453
- 0 : 111 ,
454
- 1 : 112 ,
455
- 2 : 113 ,
456
- 3 : 114 ,
457
- 4 : 115 ,
458
- 5 : 121 ,
459
- 6 : 122 ,
460
- 7 : 123 ,
461
- 8 : 124 ,
462
- 9 : 131 ,
463
- 10 : 132 ,
464
- 11 : 133 ,
465
- 12 : 134 ,
466
- 13 : 135 ,
467
- 14 : 211 ,
468
- 15 : 212 ,
469
- 16 : 213 ,
470
- 17 : 214 ,
471
- 18 : 221
478
+ 0 : 11 ,
479
+ 1 : 12 ,
480
+ 2 : 13 ,
481
+ 3 : 21 ,
482
+ 4 : 22 ,
472
483
}
473
484
mapping_cats_general_expr = F .create_map ([F .lit (x ) for x in chain (* mapping_cats_general .items ())])
474
485
475
- col_fine_class = F . col ( 'cbpf_preds' ). getItem ( 0 ). astype ( 'int' )
476
- df = df .withColumn ('cats_fine_class ' , mapping_cats_general_expr [col_fine_class ])
477
- df = df .withColumn ('cats_fine_max_prob ' , F .col ( 'cbpf_preds' ). getItem ( 1 ))
486
+ df = df . withColumn ( 'argmax' , F . expr ( 'array_position(cbpf_preds, array_max(cbpf_preds)) - 1' ) )
487
+ df = df .withColumn ('cats_broad_class ' , mapping_cats_general_expr [df [ 'argmax' ] ])
488
+ df = df .withColumn ('cats_broad_max_prob ' , F .array_max ( df [ 'cbpf_preds' ] ))
478
489
479
490
# AGN
480
491
args_forced = [
@@ -485,13 +496,18 @@ def apply_science_modules_elasticc(df: DataFrame, logger: Logger) -> DataFrame:
485
496
]
486
497
df = df .withColumn ('rf_agn_vs_nonagn' , agn_elasticc (* args_forced ))
487
498
488
- # T2
489
- df = df .withColumn ('t2_broad_class' , F .lit (0 ))
490
- df = df .withColumn ('t2_broad_max_prob' , F .lit (0.0 ))
499
+ # SLSN
500
+ args_forced = [
501
+ 'diaObject.diaObjectId' , 'cmidPointTai' , 'cpsFlux' , 'cpsFluxErr' , 'cfilterName' ,
502
+ 'diaSource.ra' , 'diaSource.decl' ,
503
+ 'diaObject.hostgal_zphot' , 'diaObject.hostgal_zphot_err' ,
504
+ 'diaObject.hostgal_ra' , 'diaObject.hostgal_dec'
505
+ ]
506
+ df = df .withColumn ('rf_slsn_vs_nonslsn' , slsn_elasticc (* args_forced ))
491
507
492
508
# Drop temp columns
493
509
df = df .drop (* expanded )
494
- df = df .drop (* ['preds_snn' , 'cbpf_preds' , 'redshift' , 'redshift_err' , 'cdsxmatch' , 'roid' ])
510
+ df = df .drop (* ['preds_snn' , 'cbpf_preds' , 'redshift' , 'redshift_err' , 'cdsxmatch' , 'roid' , 'argmax' ])
495
511
496
512
return df
497
513
0 commit comments