
Commit: lib
zphang committed Sep 7, 2020
1 parent cf80f34 commit f88bb82
Showing 15 changed files with 22 additions and 64 deletions.
43 changes: 0 additions & 43 deletions hf.py

This file was deleted.

Empty file added lm_eval/__init__.py
File renamed without changes.
4 changes: 2 additions & 2 deletions models/__init__.py → lm_eval/models/__init__.py
@@ -1,6 +1,6 @@
 import importlib
 import os
-from ..base import Registry
+from lm_eval.base import Registry
 
 MODEL_REGISTRY = Registry(registry_name="models")
 # Load all modules in models directory to populate registry
@@ -13,7 +13,7 @@
         and (file.endswith('.py') or os.path.isdir(path))
     ):
         module_name = file[:file.find('.py')] if file.endswith('.py') else file
-        module = importlib.import_module('lm_evaluation_harness.models.' + module_name)
+        module = importlib.import_module('lm_eval.models.' + module_name)
 
 
 def get_model(model_name):
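The Registry class referenced here lives in lm_eval/base.py, which this commit does not touch, so its exact interface is not shown. As a rough sketch of how this kind of import-time registry usually works (the register method and decorator style are assumptions, not code from this repository; only the registry_name argument and the .registry dict are confirmed by the diff):

# Hypothetical sketch of lm_eval.base.Registry; the real class is not shown
# in this commit, so the method names here are assumptions.
class Registry:
    def __init__(self, registry_name):
        self.registry_name = registry_name
        self.registry = {}  # maps a string name to the registered class

    def register(self, name):
        # Intended to be used as a decorator, e.g. @MODEL_REGISTRY.register("gpt2")
        def decorator(cls):
            self.registry[name] = cls
            return cls
        return decorator

The importlib loop above exists to trigger those decorators: importing every module in the package makes each model add itself to MODEL_REGISTRY as a side effect, so get_model can be a plain dictionary lookup.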
4 changes: 1 addition & 3 deletions models/dummy.py → lm_eval/models/dummy.py
@@ -1,6 +1,4 @@
-import transformers
-import torch
-from ..base import LM
+from lm_eval.base import LM
 from . import MODEL_REGISTRY
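Dropping the transformers and torch imports makes sense here: a dummy model has no use for them. For context, a model module is expected to register itself when imported; a minimal sketch of what the rest of dummy.py plausibly looks like (the decorator form, the registered name, and the constant return value are all assumptions; only the loglikelihood signature is confirmed, by the gpt3.py hunk below):

# Hypothetical body of dummy.py; the registration style is assumed.
@MODEL_REGISTRY.register("dummy")
class DummyLM(LM):
    def loglikelihood(self, context, continuation):
        # A stand-in model for wiring tests: returns a fixed score.
        return 0.0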
4 changes: 2 additions & 2 deletions models/gpt2.py → lm_eval/models/gpt2.py
@@ -1,8 +1,8 @@
 import transformers
 import torch
 import torch.nn.functional as F
-from ..base import LM
-from .. import utils
+from lm_eval.base import LM
+from lm_eval import utils
 from . import MODEL_REGISTRY
7 changes: 4 additions & 3 deletions models/gpt3.py → lm_eval/models/gpt3.py
@@ -1,8 +1,8 @@
 import os
 import openai
 import transformers
-from ..base import LM
-from .. import utils
+from lm_eval.base import LM
+from lm_eval import utils
 from . import MODEL_REGISTRY
@@ -15,7 +15,7 @@ def __init__(self, engine):
         openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
 
     @classmethod
-    def create_from_args(cls, arg_string):
+    def create_from_arg_string(cls, arg_string):
         args = utils.simple_parse_args_string(arg_string)
         return cls(engine=args.get("engine", "davinci"))
@@ -37,6 +37,7 @@ def loglikelihood(self, context, continuation):
         response = openai.Completion.create(
             engine=self.engine,
             prompt=full_text,
+            echo=True,
             max_tokens=0, temperature=0.0,
             logprobs=0,
         )
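The echo=True addition is the substantive change in the last hunk. With max_tokens=0 the legacy Completions endpoint generates nothing, and echo=True asks it to return the prompt itself together with per-token logprobs, which is exactly what loglikelihood scoring needs. A sketch of how the continuation's log-probability could then be read out of the response (field names follow the legacy openai<1.0 API; the offset-based slicing is an assumption, not necessarily the harness's exact code):

# Sketch: score the continuation from an echoed, zero-generation completion.
logprobs = response.choices[0].logprobs
# text_offset holds each token's character offset into full_text, so the
# continuation's tokens are those starting at or after its offset.
continuation_start = len(full_text) - len(continuation)
continuation_logprob = sum(
    lp
    for lp, offset in zip(logprobs.token_logprobs, logprobs.text_offset)
    if offset >= continuation_start
)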
4 changes: 2 additions & 2 deletions tasks/__init__.py → lm_eval/tasks/__init__.py
@@ -1,6 +1,6 @@
 import importlib
 import os
-from ..base import Registry
+from lm_eval.base import Registry
 
 TASK_REGISTRY = Registry(registry_name="tasks")
 # Load all modules in models directory to populate registry
@@ -13,7 +13,7 @@
         and (file.endswith('.py') or os.path.isdir(path))
     ):
         module_name = file[:file.find('.py')] if file.endswith('.py') else file
-        module = importlib.import_module('lm_evaluation_harness.tasks.' + module_name)
+        module = importlib.import_module('lm_eval.tasks.' + module_name)
 
 
 ALL_TASKS = sorted(list(TASK_REGISTRY.registry))
File renamed without changes.
2 changes: 1 addition & 1 deletion tasks/coqa.py → lm_eval/tasks/coqa.py
@@ -1,6 +1,6 @@
 import json
 import random
-from ..base import Dataset
+from lm_eval.base import Dataset
 from . import TASK_REGISTRY
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
18 changes: 10 additions & 8 deletions main.py
@@ -1,41 +1,43 @@
 import argparse
 import json
 
-import models
-import tasks
+from lm_eval import models, tasks
 
 
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', required=True)
     parser.add_argument('--model_args', default="")
     parser.add_argument('--tasks', default="all_tasks")
     parser.add_argument('--provide_description', action="store_true")
-    parser.add_argument('--new_fewshot', action="store_true")
+    parser.add_argument('--num_fewshot', type=int, default=1)
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
-    model = models.get_model(args.model).create_from_arg_string(args.model_args)
+    lm = models.get_model(args.model).create_from_arg_string(args.model_args)
     if args.tasks == "all_tasks":
         task_names = tasks.ALL_TASKS
     else:
         task_names = args.tasks.split(",")
-    task_list = {
+    task_dict = {
         task_name: tasks.get_task(task_name)()
         for task_name in task_names
     }
     results = {}
-    for task_name, task in task_list:
+    for task_name, task in task_dict.items():
         if not task.has_validation_docs():
             continue
         result = task.evaluate(
             docs=task.validation_docs(),
             lm=lm,
             provide_description=args.provide_description,
-            num_fewshot=args.new_fewshot,
+            num_fewshot=args.num_fewshot,
         )
         results[task_name] = result
     print(json.dumps(results, indent=2))
 
 
 if __name__ == "__main__":
-    main()
+    main()
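Beyond the package move, this file's hunks fix three real bugs: --num_fewshot was misspelled as --new_fewshot and declared as a boolean flag rather than an int, the task dict was iterated without .items() (unpacking keys into two names would raise ValueError), and the model was bound to `model` while task.evaluate referenced `lm` (a NameError). With those fixed, an invocation consistent with this commit would look something like the line below (the model and task names come from files in this diff, and the key=value form of --model_args is inferred from simple_parse_args_string, so treat both as assumptions):

python main.py --model gpt3 --model_args engine=davinci --tasks coqa --num_fewshot 1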
