1
- from typing import List , Dict , Set , Tuple , Protocol
1
+ from typing import List , Dict , Tuple
2
2
from ..protocols import IdentifierProtocol
3
- from ..helpers import RelationshipStrategy , ModelType
4
3
from .model_suggester import ModelSuggester
5
- from ..prompts import prompts as ps
6
4
import guidance
7
5
from guidance import system , user , assistant , gen
8
6
import re
9
7
10
8
11
9
class IdentificationSuggester (IdentifierProtocol ):
12
-
13
10
CONTEXT : str = """causal mechanisms"""
14
11
15
- def __init__ (self , llm ):
16
- if llm == 'gpt-4' :
17
- self .llm = guidance .models .OpenAI ('gpt-4' )
18
- self .model_suggester = ModelSuggester ('gpt-4' )
12
+ def __init__ (self , llm = None ):
13
+ if llm is not None :
14
+ if (llm == 'gpt-4' ):
15
+ self .llm = guidance .models .OpenAI ('gpt-4' )
16
+ self .model_suggester = ModelSuggester ('gpt-4' )
19
17
20
18
# def suggest_estimand(
21
19
# self,
@@ -120,7 +118,7 @@ def suggest_backdoor(
120
118
outcome : str ,
121
119
factors_list : list (),
122
120
expertise_list : list (),
123
- analysis_context : list () = CONTEXT ,
121
+ analysis_context = CONTEXT ,
124
122
stakeholders : list () = None
125
123
):
126
124
backdoor_set = self .model_suggester .suggest_confounders (
@@ -133,14 +131,14 @@ def suggest_backdoor(
133
131
)
134
132
return backdoor_set
135
133
136
- #TODO:implement
134
+ # TODO:implement
137
135
def suggest_frontdoor (
138
136
self ,
139
137
treatment : str ,
140
138
outcome : str ,
141
139
factors_list : list (),
142
140
expertise_list : list (),
143
- analysis_context : list () = CONTEXT ,
141
+ analysis_context = CONTEXT ,
144
142
stakeholders : list () = None
145
143
):
146
144
pass
@@ -151,7 +149,7 @@ def suggest_mediators(
151
149
outcome : str ,
152
150
factors_list : list (),
153
151
expertise_list : list (),
154
- analysis_context : list () = CONTEXT ,
152
+ analysis_context = CONTEXT ,
155
153
stakeholders : list () = None
156
154
):
157
155
expert_list : List [str ] = list ()
@@ -170,43 +168,28 @@ def suggest_mediators(
170
168
if factors_list [i ] != treatment and factors_list [i ] != outcome :
171
169
edited_factors_list .append (factors_list [i ])
172
170
173
- if len (expert_list ) > 1 :
174
- for expert in expert_list :
175
- mediators_edges , mediators_list = self .request_mediators (
176
- treatment = treatment ,
177
- outcome = outcome ,
178
- analysis_context = analysis_context ,
179
- domain_expertise = expert ,
180
- factors_list = edited_factors_list ,
181
- mediators_edges = mediators_edges
182
- )
183
- for m in mediators_list :
184
- if m not in mediators :
185
- mediators .append (m )
186
- else :
171
+ for expert in expert_list :
187
172
mediators_edges , mediators_list = self .request_mediators (
188
173
treatment = treatment ,
189
174
outcome = outcome ,
190
- analysis_context = analysis_context ,
191
- domain_expertise = expert_list [0 ],
175
+ domain_expertise = expert ,
192
176
factors_list = edited_factors_list ,
193
177
mediators_edges = mediators_edges ,
178
+ analysis_context = analysis_context
194
179
)
195
-
196
180
for m in mediators_list :
197
181
if m not in mediators :
198
182
mediators .append (m )
199
-
200
183
return mediators_edges , mediators
201
184
202
185
def request_mediators (
203
186
self ,
204
187
treatment ,
205
188
outcome ,
206
- analysis_context ,
207
189
domain_expertise ,
208
190
factors_list ,
209
- mediators_edges
191
+ mediators_edges ,
192
+ analysis_context = CONTEXT
210
193
):
211
194
mediators : List [str ] = list ()
212
195
@@ -254,9 +237,7 @@ def request_mediators(
254
237
# to not add it twice into the list
255
238
if factor in factors_list and factor not in mediators :
256
239
mediators .append (factor )
257
- success = True
258
- else :
259
- success = False
240
+ success = True
260
241
261
242
except KeyError :
262
243
success = False
@@ -281,7 +262,7 @@ def suggest_ivs(
281
262
outcome : str ,
282
263
factors_list : list (),
283
264
expertise_list : list (),
284
- analysis_context : list () = CONTEXT ,
265
+ analysis_context = CONTEXT ,
285
266
stakeholders : list () = None
286
267
):
287
268
expert_list : List [str ] = list ()
@@ -300,26 +281,20 @@ def suggest_ivs(
300
281
if factors_list [i ] != treatment and factors_list [i ] != outcome :
301
282
edited_factors_list .append (factors_list [i ])
302
283
303
- if len (expert_list ) > 1 :
304
- for expert in expert_list :
305
- self .request_ivs (
306
- treatment = treatment ,
307
- outcome = outcome ,
308
- analysis_context = analysis_context ,
309
- domain_expertise = expert ,
310
- factors_list = edited_factors_list ,
311
- iv_edges = iv_edges ,
312
- )
313
- else :
314
- self .request_ivs (
284
+ for expert in expert_list :
285
+ iv_edges , iv_list = self .request_ivs (
315
286
treatment = treatment ,
316
287
outcome = outcome ,
317
288
analysis_context = analysis_context ,
318
- domain_expertise = expert_list [ 0 ] ,
289
+ domain_expertise = expert ,
319
290
factors_list = edited_factors_list ,
320
291
iv_edges = iv_edges ,
321
292
)
322
293
294
+ for m in iv_list :
295
+ if m not in ivs :
296
+ ivs .append (m )
297
+
323
298
return iv_edges , ivs
324
299
325
300
def request_ivs (
@@ -376,9 +351,7 @@ def request_ivs(
376
351
for factor in iv_factors :
377
352
if factor in factors_list and factor not in ivs :
378
353
ivs .append (factor )
379
- success = True
380
- else :
381
- success = False
354
+ success = True
382
355
383
356
except KeyError :
384
357
success = False
@@ -390,4 +363,4 @@ def request_ivs(
390
363
else :
391
364
iv_edges [(element , treatment )] = 1
392
365
393
- return iv_edges
366
+ return iv_edges , ivs
0 commit comments