Skip to content
48 changes: 24 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ pip install pywhyllm
PyWhy-LLM seamlessly integrates into your existing causal inference process. Import the necessary classes and start exploring the power of LLM-augmented causal analysis.

```python
from pywhyllm import ModelSuggester, IdentificationSuggester, ValidationSuggester
from pywhyllm.suggesters.model_suggester import ModelSuggester
from pywhyllm.suggesters.identification_suggester import IdentificationSuggester
from pywhyllm.suggesters.validation_suggester import ValidationSuggester
from pywhyllm import RelationshipStrategy

```

Expand All @@ -34,17 +37,20 @@ from pywhyllm import ModelSuggester, IdentificationSuggester, ValidationSuggeste

```python
# Create instance of Modeler
modeler = Modeler()
modeler = ModelSuggester('gpt-4')

all_factors = ["smoking", "lung cancer", "exercise habits", "air pollution exposure"]
treatment = "smoking"
outcome = "lung cancer"

# Suggest a list of domain expertises
domain_expertises = modeler.suggest_domain_expertises(all_factors)

# Suggest a set of potential confounders
suggested_confounders = modeler.suggest_confounders(variables=_variables, treatment=treatment, outcome=outcome, llm=gpt4)
suggested_confounders = modeler.suggest_confounders(treatment, outcome, all_factors, domain_expertises)

# Suggest pair-wise relationship between variables
suggested_dag = modeler.suggest_relationships(variables=selected_variables, llm=gpt4)

plt.figure(figsize=(10, 5))
nx.draw(suggested_dag, with_labels=True, node_color='lightblue')
plt.show()
suggested_dag = modeler.suggest_relationships(treatment, outcome, all_factors, domain_expertises, RelationshipStrategy.Pairwise)
```


Expand All @@ -54,15 +60,13 @@ plt.show()

```python
# Create instance of Identifier
identifier = Identifier()
identifier = IdentificationSuggester('gpt-4')

# Suggest a backdoor set, front door set, and iv set
suggested_backdoor = identifier.suggest_backdoor(llm=gpt4, treatment=treatment, outcome=outcome, confounders=suggested_confounders)
suggested_frontdoor = identifier.suggest_frontdoor(llm=gpt4, treatment=treatment, outcome=outcome, confounders=suggested_confounders)
suggested_iv = identifier.suggest_iv(llm=gpt4, treatment=treatment, outcome=outcome, confounders=suggested_confounders)
# Suggest a backdoor set, mediator set, and iv set
suggested_backdoor = identifier.suggest_backdoor(treatment, outcome, all_factors, domain_expertises)
suggested_mediators = identifier.suggest_mediators(treatment, outcome, all_factors, domain_expertises)
suggested_iv = identifier.suggest_ivs(treatment, outcome, all_factors, domain_expertises)

# Suggest an estimand based on the suggester backdoor set, front door set, and iv set
estimand = identifier.suggest_estimand(confounders=suggested_confounders, treatment=treatment, outcome=outcome, backdoor=suggested_backdoor, frontdoor=suggested_frontdoor, iv=suggested_iv, llm=gpt4)
```


Expand All @@ -72,20 +76,16 @@ estimand = identifier.suggest_estimand(confounders=suggested_confounders, treatm

```python
# Create instance of Validator
validator = Validator()
validator = ValidationSuggester('gpt-4')

# Suggest a critique of the provided DAG
suggested_critiques_dag = validator.critique_graph(graph=suggested_dag, llm=gpt4)
# Suggest a critique of the edges in provided DAG
suggested_critiques_dag = validator.critique_graph(all_factors, suggested_dag, domain_expertises, RelationshipStrategy.Pairwise)

# Suggest latent confounders
suggested_latent_confounders = validator.suggest_latent_confounders(treatment=treatment, outcome=outcome, llm=gpt4)
suggested_latent_confounders = validator.suggest_latent_confounders(treatment, outcome, all_factors, domain_expertises)

# Suggest negative controls
suggested_negative_controls = validator.suggest_negative_controls(variables=selected_variables, treatment=treatment, outcome=outcome, llm=gpt4)

plt.figure(figsize=(10, 5))
nx.draw(suggested_critiques_dag, with_labels=True, node_color='lightblue')
plt.show()
suggested_negative_controls = validator.suggest_negative_controls(treatment, outcome, all_factors, domain_expertises)

```

Expand Down
56 changes: 28 additions & 28 deletions pywhyllm/suggesters/identification_suggester.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, llm=None):
# self,
# treatment: str,
# outcome: str,
# factors_list: list(),
# all_factors: list(),
# llm: guidance.models,
# backdoor: Set[str] = None,
# frontdoor: Set[str] = None,
Expand All @@ -41,7 +41,7 @@ def __init__(self, llm=None):
# backdoor_edges, backdoor_set = self.suggest_backdoor(
# treatment=treatment,
# outcome=outcome,
# factors_list=factors_list,
# all_factors=all_factors,
# llm=llm,
# experts=experts,
# analysis_context=analysis_context,
Expand All @@ -66,7 +66,7 @@ def __init__(self, llm=None):
# frontdoor_edges, frontdoor_set = self.suggest_frontdoor(
# treatment=treatment,
# outcome=outcome,
# factors_list=factors_list,
# all_factors=all_factors,
# llm=llm,
# experts=experts,
# analysis_context=analysis_context,
Expand All @@ -87,7 +87,7 @@ def __init__(self, llm=None):
# ivs_edges, ivs_set = self.suggest_ivs(
# treatment=treatment,
# outcome=outcome,
# factors_list=factors_list,
# all_factors=all_factors,
# llm=llm,
# experts=experts,
# analysis_context=analysis_context,
Expand Down Expand Up @@ -116,15 +116,15 @@ def suggest_backdoor(
self,
treatment: str,
outcome: str,
factors_list: list(),
all_factors: list(),
expertise_list: list(),
analysis_context=CONTEXT,
analysis_context: str = CONTEXT,
stakeholders: list() = None
):
backdoor_set = self.model_suggester.suggest_confounders(
treatment=treatment,
outcome=outcome,
factors_list=factors_list,
all_factors=all_factors,
expertise_list=expertise_list,
analysis_context=analysis_context,
stakeholders=stakeholders
Expand All @@ -136,9 +136,9 @@ def suggest_frontdoor(
self,
treatment: str,
outcome: str,
factors_list: list(),
all_factors: list(),
expertise_list: list(),
analysis_context=CONTEXT,
analysis_context: str = CONTEXT,
stakeholders: list() = None
):
pass
Expand All @@ -147,9 +147,9 @@ def suggest_mediators(
self,
treatment: str,
outcome: str,
factors_list: list(),
all_factors: list(),
expertise_list: list(),
analysis_context=CONTEXT,
analysis_context: str = CONTEXT,
stakeholders: list() = None
):
expert_list: List[str] = list()
Expand All @@ -164,16 +164,16 @@ def suggest_mediators(
mediators_edges[(treatment, outcome)] = 1

edited_factors_list: List[str] = []
for i in range(len(factors_list)):
if factors_list[i] != treatment and factors_list[i] != outcome:
edited_factors_list.append(factors_list[i])
for i in range(len(all_factors)):
if all_factors[i] != treatment and all_factors[i] != outcome:
edited_factors_list.append(all_factors[i])

for expert in expert_list:
mediators_edges, mediators_list = self.request_mediators(
treatment=treatment,
outcome=outcome,
domain_expertise=expert,
factors_list=edited_factors_list,
all_factors=edited_factors_list,
mediators_edges=mediators_edges,
analysis_context=analysis_context
)
Expand All @@ -187,9 +187,9 @@ def request_mediators(
treatment,
outcome,
domain_expertise,
factors_list,
all_factors,
mediators_edges,
analysis_context=CONTEXT
analysis_context: str = CONTEXT
):
mediators: List[str] = list()

Expand Down Expand Up @@ -218,7 +218,7 @@ def request_mediators(
on the causal chain that links the {treatment} to the {outcome}? From your perspective as an expert in
{domain_expertise}, which factor(s) of the following factors, if any at all, mediates, is/are on the causal
chain, that links the {treatment} to the {outcome}? Then provide your step by step chain of thoughts within
the tags <thinking></thinking>. factor_names : {factors_list} Wrap the name of the factor(s), if any at all,
the tags <thinking></thinking>. factor_names : {all_factors} Wrap the name of the factor(s), if any at all,
that has/have a high likelihood of directly influencing and causing the assignment of the {outcome} and also
has/have a high likelihood of being directly influenced and caused by the assignment of the {treatment} within
the tags <mediating_factor>factor_name</mediating_factor>. Where factor_name is one of the items within the
Expand All @@ -237,7 +237,7 @@ def request_mediators(
if mediating_factor:
for factor in mediating_factor:
# to not add it twice into the list
if factor in factors_list and factor not in mediators:
if factor in all_factors and factor not in mediators:
mediators.append(factor)
success = True

Expand All @@ -262,9 +262,9 @@ def suggest_ivs(
self,
treatment: str,
outcome: str,
factors_list: list(),
all_factors: list(),
expertise_list: list(),
analysis_context=CONTEXT,
analysis_context: str = CONTEXT,
stakeholders: list() = None
):
expert_list: List[str] = list()
Expand All @@ -279,17 +279,17 @@ def suggest_ivs(
iv_edges[(treatment, outcome)] = 1

edited_factors_list: List[str] = []
for i in range(len(factors_list)):
if factors_list[i] != treatment and factors_list[i] != outcome:
edited_factors_list.append(factors_list[i])
for i in range(len(all_factors)):
if all_factors[i] != treatment and all_factors[i] != outcome:
edited_factors_list.append(all_factors[i])

for expert in expert_list:
iv_edges, iv_list = self.request_ivs(
treatment=treatment,
outcome=outcome,
analysis_context=analysis_context,
domain_expertise=expert,
factors_list=edited_factors_list,
all_factors=edited_factors_list,
iv_edges=iv_edges,
)

Expand All @@ -305,7 +305,7 @@ def request_ivs(
outcome,
analysis_context,
domain_expertise,
factors_list,
all_factors,
iv_edges
):
ivs: List[str] = list()
Expand Down Expand Up @@ -338,7 +338,7 @@ def request_ivs(
the {outcome}? Which factor(s) of the following factors, if any at all, are (an) instrumental variable(s)
to the causal relationship of the {treatment} causing the {outcome}? Be concise and keep your thinking
within two paragraphs. Then provide your step by step chain of thoughts within the tags
<thinking></thinking>. factor_names : {factors_list} Wrap the name of the factor(s), if there are any at
<thinking></thinking>. factor_names : {all_factors} Wrap the name of the factor(s), if there are any at
all, that both has/have a high likelihood of influecing and causing the {treatment} and has/have a very low
likelihood of influencing and causing the {outcome}, within the tags <iv_factor>factor_name</iv_factor>.
Where factor_name is one of the items within the factor_names list. If a factor does not have a high
Expand All @@ -353,7 +353,7 @@ def request_ivs(

if iv_factors:
for factor in iv_factors:
if factor in factors_list and factor not in ivs:
if factor in all_factors and factor not in ivs:
ivs.append(factor)
success = True

Expand Down
Loading