Skip to content

Commit d537d73

Browse files
authored
Adding first set of tests for Suggesters
Wrote test classes and methods for the following classes: SimpleModelSuggester SimpleIdentificationSuggester ModelSuggester IdentificationSuggeste
2 parents f22a192 + c149fb1 commit d537d73

16 files changed

+620
-216
lines changed

pywhyllm/suggesters/identification_suggester.py

Lines changed: 26 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
1-
from typing import List, Dict, Set, Tuple, Protocol
1+
from typing import List, Dict, Tuple
22
from ..protocols import IdentifierProtocol
3-
from ..helpers import RelationshipStrategy, ModelType
43
from .model_suggester import ModelSuggester
5-
from ..prompts import prompts as ps
64
import guidance
75
from guidance import system, user, assistant, gen
86
import re
97

108

119
class IdentificationSuggester(IdentifierProtocol):
12-
1310
CONTEXT: str = """causal mechanisms"""
1411

15-
def __init__(self, llm):
16-
if llm == 'gpt-4':
17-
self.llm = guidance.models.OpenAI('gpt-4')
18-
self.model_suggester = ModelSuggester('gpt-4')
12+
def __init__(self, llm=None):
13+
if llm is not None:
14+
if (llm == 'gpt-4'):
15+
self.llm = guidance.models.OpenAI('gpt-4')
16+
self.model_suggester = ModelSuggester('gpt-4')
1917

2018
# def suggest_estimand(
2119
# self,
@@ -120,7 +118,7 @@ def suggest_backdoor(
120118
outcome: str,
121119
factors_list: list(),
122120
expertise_list: list(),
123-
analysis_context: list() = CONTEXT,
121+
analysis_context=CONTEXT,
124122
stakeholders: list() = None
125123
):
126124
backdoor_set = self.model_suggester.suggest_confounders(
@@ -133,14 +131,14 @@ def suggest_backdoor(
133131
)
134132
return backdoor_set
135133

136-
#TODO:implement
134+
# TODO:implement
137135
def suggest_frontdoor(
138136
self,
139137
treatment: str,
140138
outcome: str,
141139
factors_list: list(),
142140
expertise_list: list(),
143-
analysis_context: list() = CONTEXT,
141+
analysis_context=CONTEXT,
144142
stakeholders: list() = None
145143
):
146144
pass
@@ -151,7 +149,7 @@ def suggest_mediators(
151149
outcome: str,
152150
factors_list: list(),
153151
expertise_list: list(),
154-
analysis_context: list() = CONTEXT,
152+
analysis_context=CONTEXT,
155153
stakeholders: list() = None
156154
):
157155
expert_list: List[str] = list()
@@ -170,43 +168,28 @@ def suggest_mediators(
170168
if factors_list[i] != treatment and factors_list[i] != outcome:
171169
edited_factors_list.append(factors_list[i])
172170

173-
if len(expert_list) > 1:
174-
for expert in expert_list:
175-
mediators_edges, mediators_list = self.request_mediators(
176-
treatment=treatment,
177-
outcome=outcome,
178-
analysis_context=analysis_context,
179-
domain_expertise=expert,
180-
factors_list=edited_factors_list,
181-
mediators_edges=mediators_edges
182-
)
183-
for m in mediators_list:
184-
if m not in mediators:
185-
mediators.append(m)
186-
else:
171+
for expert in expert_list:
187172
mediators_edges, mediators_list = self.request_mediators(
188173
treatment=treatment,
189174
outcome=outcome,
190-
analysis_context=analysis_context,
191-
domain_expertise=expert_list[0],
175+
domain_expertise=expert,
192176
factors_list=edited_factors_list,
193177
mediators_edges=mediators_edges,
178+
analysis_context=analysis_context
194179
)
195-
196180
for m in mediators_list:
197181
if m not in mediators:
198182
mediators.append(m)
199-
200183
return mediators_edges, mediators
201184

202185
def request_mediators(
203186
self,
204187
treatment,
205188
outcome,
206-
analysis_context,
207189
domain_expertise,
208190
factors_list,
209-
mediators_edges
191+
mediators_edges,
192+
analysis_context=CONTEXT
210193
):
211194
mediators: List[str] = list()
212195

@@ -254,9 +237,7 @@ def request_mediators(
254237
# to not add it twice into the list
255238
if factor in factors_list and factor not in mediators:
256239
mediators.append(factor)
257-
success = True
258-
else:
259-
success = False
240+
success = True
260241

261242
except KeyError:
262243
success = False
@@ -281,7 +262,7 @@ def suggest_ivs(
281262
outcome: str,
282263
factors_list: list(),
283264
expertise_list: list(),
284-
analysis_context: list() = CONTEXT,
265+
analysis_context=CONTEXT,
285266
stakeholders: list() = None
286267
):
287268
expert_list: List[str] = list()
@@ -300,26 +281,20 @@ def suggest_ivs(
300281
if factors_list[i] != treatment and factors_list[i] != outcome:
301282
edited_factors_list.append(factors_list[i])
302283

303-
if len(expert_list) > 1:
304-
for expert in expert_list:
305-
self.request_ivs(
306-
treatment=treatment,
307-
outcome=outcome,
308-
analysis_context=analysis_context,
309-
domain_expertise=expert,
310-
factors_list=edited_factors_list,
311-
iv_edges=iv_edges,
312-
)
313-
else:
314-
self.request_ivs(
284+
for expert in expert_list:
285+
iv_edges, iv_list = self.request_ivs(
315286
treatment=treatment,
316287
outcome=outcome,
317288
analysis_context=analysis_context,
318-
domain_expertise=expert_list[0],
289+
domain_expertise=expert,
319290
factors_list=edited_factors_list,
320291
iv_edges=iv_edges,
321292
)
322293

294+
for m in iv_list:
295+
if m not in ivs:
296+
ivs.append(m)
297+
323298
return iv_edges, ivs
324299

325300
def request_ivs(
@@ -376,9 +351,7 @@ def request_ivs(
376351
for factor in iv_factors:
377352
if factor in factors_list and factor not in ivs:
378353
ivs.append(factor)
379-
success = True
380-
else:
381-
success = False
354+
success = True
382355

383356
except KeyError:
384357
success = False
@@ -390,4 +363,4 @@ def request_ivs(
390363
else:
391364
iv_edges[(element, treatment)] = 1
392365

393-
return iv_edges
366+
return iv_edges, ivs

0 commit comments

Comments
 (0)