Skip to content

Commit 4e0d4ae

Browse files
committed
Add more test sequences, scatter plot visualization
1 parent 83fca0e commit 4e0d4ae

File tree

3 files changed

+68
-20
lines changed

3 files changed

+68
-20
lines changed

Pipfile

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ name = "pypi"
66
[packages]
77
transformers = {extras = ["torch"], version = "*"}
88
scipy = "*"
9+
matplotlib = "*"
910

1011
[dev-packages]
1112

run_sentiment_classifier.py

+53-20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
13
# https://stackoverflow.com/questions/7370801/how-to-measure-elapsed-time-in-python
24
from timeit import default_timer as timer
35

@@ -8,6 +10,12 @@
810
from utils import preprocess, download_label_mapping, output_vector_to_labels
911

1012

13+
def read_test_sequences(path: str) -> list[str]:
    """Read test sequences from *path*, one sequence per line.

    Trailing whitespace (including the newline) is stripped from each
    line; empty lines are kept as empty strings.

    :param path: path to a UTF-8 text file of test sequences.
    :return: list of sequences in file order.
    """
    # Explicit encoding: the bundled test_sequences.txt contains
    # non-ASCII characters (curly apostrophes etc.), so relying on the
    # platform-default encoding can mis-decode or raise on Windows.
    with open(path, 'r', encoding='utf-8') as f:
        # Iterate the file directly instead of readlines(): same result,
        # no intermediate list.
        return [line.rstrip() for line in f]
17+
18+
1119
def run_model(model, tokenized_input):
    """Run *model* on an already-tokenized input and convert the raw
    output vector into human-readable sentiment labels.

    :param model: a callable HF model (eager or TorchScript-traced).
    :param tokenized_input: dict of tensors produced by the tokenizer,
        expanded into the model call as keyword arguments.
    :return: whatever output_vector_to_labels produces for this output.
    """
    raw_output = model(**tokenized_input)
    # NOTE(review): the label mapping is re-downloaded on every call —
    # acceptable for a small benchmark script, but hoist it if this
    # ever runs in a hot loop.
    label_mapping = download_label_mapping()
    return output_vector_to_labels(raw_output, label_mapping)
@@ -19,37 +27,62 @@ def check_inference_time(model, tokenized_input):
1927
elapsed_time = timer()-t
2028
return elapsed_time
2129

30+
2231
if __name__ == "__main__":
2332
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment", torchscript=True)
2433
clf = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment", torchscript=True)
2534

26-
input_texts = [
27-
"Hello world",
28-
"Happy birthday",
29-
"I don't think it's gonna work",
30-
"I enjoy natural language understanding"
31-
]
35+
n_experiments = 5
36+
input_texts = [preprocess(x) for x in read_test_sequences("test_sequences.txt")]
3237

33-
# 1. Vanilla
38+
# 1. Eager
39+
eager_measurements = np.zeros((n_experiments, len(input_texts)))
3440
tokenized_inputs = [tokenizer(x, return_tensors='pt') for x in input_texts]
35-
outputs = [run_model(clf, x) for x in tokenized_inputs]
36-
output_times = [check_inference_time(clf, x) for x in tokenized_inputs]
37-
38-
for inp, out in zip(input_texts, outputs):
39-
print(inp, out)
40-
41-
print(output_times)
4241

43-
print("")
42+
for i in range(n_experiments):
43+
# outputs = [run_model(clf, x) for x in tokenized_inputs]
44+
eager_measurements[i] = [check_inference_time(clf, x) for x in tokenized_inputs]
45+
# for inp, out in zip(input_texts, outputs):
46+
# print(inp, '\n', out, '\n')
47+
# print(output_times)
4448

49+
4550
# 2. TorchScript (JIT)
51+
script_measurements = np.zeros((n_experiments, len(input_texts)))
4652
tokenized_inputs = [tokenizer(x, return_tensors='pt') for x in input_texts]
4753
traced_model = torch.jit.trace(clf, (tokenized_inputs[0]['input_ids'], tokenized_inputs[0]['attention_mask']))
48-
outputs = [run_model(traced_model, x) for x in tokenized_inputs]
49-
output_times = [check_inference_time(traced_model, x) for x in tokenized_inputs]
5054
# torch.jit.save(traced_model, "traced_twitter_roberta_base_sentiment.pt")
5155
# loaded_model = torch.jit.load("traced_twitter_roberta_base_sentiment.pt")
52-
for inp, out in zip(input_texts, outputs):
53-
print(inp, out)
5456

55-
print(output_times)
57+
for i in range(n_experiments):
58+
# outputs = [run_model(traced_model, x) for x in tokenized_inputs]
59+
script_measurements[i] = [check_inference_time(traced_model, x) for x in tokenized_inputs]
60+
# for inp, out in zip(input_texts, outputs):
61+
# print(inp, '\n', out, '\n')
62+
# print(output_times)
63+
64+
print(eager_measurements)
65+
print(script_measurements)
66+
67+
# Box Plot
68+
69+
eager_avgs = np.mean(eager_measurements, axis=0)
70+
script_avgs = np.mean(script_measurements, axis=0)
71+
print(eager_avgs)
72+
print(script_avgs)
73+
74+
# Scatter Plot
75+
76+
indices = np.tile(np.arange(len(input_texts)), n_experiments)
77+
eager_measurements = eager_measurements.flatten()
78+
script_measurements = script_measurements.flatten()
79+
print(indices)
80+
print(eager_measurements)
81+
82+
plt.style.use('seaborn')
83+
plt.scatter(indices, eager_measurements, label='Eager mode')
84+
plt.scatter(indices, script_measurements, label='Script mode')
85+
plt.xlabel('Sequence ID')
86+
plt.ylabel('Inference time [s]')
87+
plt.legend()
88+
plt.show()

test_sequences.txt

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
Happy birthday
2+
Taming transformers is a fun paper to read.
3+
Really happy to see this. Congratulations John :)
4+
Honored to be in such good company for the Bay Area’s Best Places to Work 2016
5+
I've always been inspired by how my friend @ramsri_goutham applies NLP to creative real-world use cases.
6+
At @gridai_ & @pytorchlightnin we strongly believe #AI should be #inclusive and celebrate #diversity
7+
We're excited to launch the third Habitat Challenge at the Embodied AI workshop with 15 research & academic institutions.
8+
I don't think it's gonna work
9+
Sometimes, You gotta hate #Windows updates.
10+
Friend requests on @facebook I’m still confused
11+
@facebook… for f***s sake… Can’t find a post posted 3mins ago due to your stupid “non chronological” timeline
12+
Netflix's Chaos Monkey tool, but instead of randomly killing containers it randomly cancels meetings on my calendar
13+
All of these things are true: 1. Tech is terrible at hiring and interviewing. 2. Tech is terrible at defining different areas of data work.
14+
I’m so frustrated that many privileged people are ignoring racial/ethnic injustices of Covid-19 deaths. Widespread apathy is making these grave disparities more severe in younger adults. Generational damage is destroying families.

0 commit comments

Comments
 (0)