Skip to content

Commit b2805df

Browse files
committed
Add pyproject and precommit configurations for black, isort and mypy and reformat the code accordingly
1 parent 47ba6b6 commit b2805df

25 files changed

+279
-404
lines changed

.pre-commit-config.yaml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
files: "^src/"
2+
3+
repos:
4+
- repo: https://github.com/psf/black
5+
rev: 22.12.0
6+
hooks:
7+
- id: black
8+
# It is recommended to specify the latest version of Python
9+
# supported by your project here, or alternatively use
10+
# pre-commit's default_language_version, see
11+
# https://pre-commit.com/#top_level-default_language_version
12+
language_version: python3.10
13+
- repo: https://github.com/pycqa/isort
14+
rev: 5.11.2
15+
hooks:
16+
- id: isort
17+
args: ["--profile", "black"]
18+
name: isort (python)
19+
- repo: https://github.com/pre-commit/mirrors-mypy
20+
rev: "v0.991" # Use the sha / tag you want to point at
21+
hooks:
22+
- id: mypy

calc_time.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import json
22
import sys
33
from pathlib import Path
4-
from matplotlib import pyplot as plt
4+
55
import numpy as np
6+
from matplotlib import pyplot as plt
67

78
folder = Path(sys.argv[1])
89
all_files = folder.glob("**/*.json")

evaluate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from matplotlib import pyplot as plt
21
import json
32
import sys
4-
from pathlib import Path
5-
from typing import TypedDict, List
63
from argparse import ArgumentParser
4+
from pathlib import Path
5+
from typing import List, TypedDict
6+
7+
from matplotlib import pyplot as plt
78
from tqdm import tqdm
89

910

pyproject.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
[tool.black]
2+
include = '\.pyi?$'
3+
line-length = 88
4+
target-version = ["py310"]
5+
6+
[tool.isort]
7+
line_length = 88
8+
profile = "black"
9+
skip_gitignore = true
10+
11+
[tool.mypy]
12+
check_untyped_defs = true
13+
follow_imports = "silent"
14+
ignore_missing_imports = true
15+
modules = ["main"]
16+
mypy_path = "src"
17+
packages = ["realm"]
18+
python_version = "3.10"

requirements.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/realm/analyze/__init__.py

Whitespace-only changes.

src/realm/analyze/java_syntax.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

src/realm/jdt_lsp.py renamed to src/realm/analyzer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
11
import itertools
2-
import pickle
32
import subprocess
43
from multiprocessing import Process
54
from multiprocessing.connection import Connection
6-
import os
75
from os import PathLike
86
from pathlib import Path
97
from typing import Any, Dict, List, Optional, cast
108

11-
import torch
12-
9+
from realm import utils
1310
from realm.generation_defs import GenerationContext, Memorization
1411
from realm.lsp import LSPClient, TextFile, spec
1512
from realm.model import CodeT5ForRealm
16-
from realm import utils
1713

1814
TIMEOUT_THRESHOULD = 300
1915

src/realm/config.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from dataclasses import dataclass
2-
from pathlib import Path
32
from enum import Enum, auto
4-
from .generation_defs import LMInferenceConfig
5-
from .utils import JsonSerializable
3+
from pathlib import Path
64
from typing import Any
75

6+
from .utils import JsonSerializable
7+
88

99
@dataclass(frozen=True)
1010
class MetaConfig(JsonSerializable):
@@ -74,6 +74,31 @@ def from_json(cls, d: str) -> "SynthesisMethod":
7474
SYNTHESIS_METHOD_REV_MAP = {value: key for key, value in SYNTHESIS_METHOD_MAP.items()}
7575

7676

77+
@dataclass(frozen=True)
78+
class LMInferenceConfig(JsonSerializable):
79+
batch_size: int
80+
temperature: float
81+
top_k: int
82+
max_new_tokens: int
83+
84+
def to_json(self) -> Any:
85+
return {
86+
"batch_size": self.batch_size,
87+
"temperature": self.temperature,
88+
"top_k": self.top_k,
89+
"max_new_tokens": self.max_new_tokens,
90+
}
91+
92+
@classmethod
93+
def from_json(cls, d: dict) -> "LMInferenceConfig":
94+
return LMInferenceConfig(
95+
int(d["batch_size"]),
96+
float(d["temperature"]),
97+
int(d["top_k"]),
98+
int(d["max_new_tokens"]),
99+
)
100+
101+
77102
@dataclass(frozen=True)
78103
class RepairConfig(JsonSerializable):
79104
n_samples: int

src/realm/d4j.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
from functools import partial
1+
import csv
22
import itertools
3+
import multiprocessing as mp
4+
import os
5+
import subprocess
6+
from functools import partial
37
from os import PathLike
48
from pathlib import Path
5-
from typing import Dict, Iterator, List, NamedTuple
6-
from unidiff import PatchSet, PatchedFile
9+
from typing import Dict, Iterable, Iterator, List, NamedTuple, TypeVar
10+
11+
import git
12+
from unidiff import PatchedFile, PatchSet
713
from unidiff.patch import Line
14+
815
from realm import utils
9-
import subprocess
10-
import multiprocessing as mp
11-
import csv
12-
from pathlib import Path
13-
from typing import Dict, Iterable, Iterator, List, TypeVar
1416
from realm.utils import chunked
15-
import git
16-
import os
1717

1818
Metadata = Dict[str, List[Dict[str, str]]]
1919

0 commit comments

Comments
 (0)