Skip to content

Commit 297ef1f

Browse files
committed
Adding tensorflow backend.
1 parent 8e5ecfb commit 297ef1f

File tree

4 files changed

+415
-4
lines changed

4 files changed

+415
-4
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
tensorflow backend for dlrm
3+
"""
4+
# pylint: disable=unused-argument,missing-docstring
5+
import torch # currently supports pytorch1.0
6+
import backend
7+
# from dlrm_s_pytorch import DLRM_Net
8+
import tensorflow as tf
9+
from tf_dlrm import logits_fn, rand_features_np
10+
import numpy as np
11+
import collections
12+
from typing import Dict, Any
13+
import sys
14+
15+
class BackendTF(backend.Backend):
    """TensorFlow graph-mode backend for DLRM.

    The graph is built by ``logits_fn`` inside a private ``tf.Graph`` and
    executed through a ``tf.compat.v1.Session``.  ``predict`` accepts and
    returns torch tensors, converting through numpy on the way in and out.
    """

    def __init__(self, dim_embed, vocab_sizes, mlp_bottom, mlp_top):
        """Collect the model hyper-parameters handed to ``logits_fn``.

        Args:
            dim_embed: embedding dimension, stored as ``params["dim_embed"]``.
            vocab_sizes: array-like exposing ``.tolist()`` (callers pass
                numpy arrays); one entry per sparse feature.
            mlp_bottom: array-like exposing ``.tolist()``; its first entry is
                used as the number of dense input features.
            mlp_top: array-like exposing ``.tolist()``.
        """
        super(BackendTF, self).__init__()
        self.sess = None
        self.model = None
        # Populated by load(); defined here so that calling predict() too
        # early fails with a clear error instead of an AttributeError.
        self.graph = None
        self.model_path = None
        self._ph_dense = None
        self._ph_sparse = None
        self._preds = None

        vocab_list = vocab_sizes.tolist()
        bottom_list = mlp_bottom.tolist()

        # A plain dict: the original defaultdict had no default factory and
        # therefore behaved exactly like one.
        self.params = {
            "dim_embed": dim_embed,
            "vocab_sizes": vocab_list,
            "mlp_bottom": bottom_list,
            "mlp_top": mlp_top.tolist(),
            "num_dense_features": bottom_list[0],
            "num_sparse_features": len(vocab_list),
            # NOTE(review): hard-coded; meaning is defined by tf_dlrm.logits_fn
            # and is not visible from this file -- confirm before changing.
            "num_tables_in_ec": 26,
            "learning_rate": 0.01,
            "opt_skip_gather": True,
            # The graph is built while this is True; load() flips it to False
            # only after construction.
            # NOTE(review): confirm that is what logits_fn expects for inference.
            "is_training": True,
        }

    def version(self):
        """Return the version string of the imported TensorFlow package."""
        return tf.__version__

    def name(self):
        """Return the backend identifier."""
        return "tf-dlrm"

    def load(self, model_path, inputs=None, outputs=None):
        """Build the graph, initialise its variables and run one warm-up batch.

        ``model_path`` is only recorded on the instance; this method does not
        restore a checkpoint -- variables come from the global initializer.
        ``inputs`` and ``outputs`` are accepted for interface compatibility
        and are not used.

        Returns:
            self, so the call can be chained.
        """
        self.model_path = model_path

        num_d = self.params["num_dense_features"]
        num_s = self.params["num_sparse_features"]
        minsize = min(self.params["vocab_sizes"])
        print("stat: ", num_d, num_s, minsize)

        # Release the session of a previous load() instead of leaking it.
        if self.sess is not None:
            self.sess.close()
            self.sess = None

        self.graph = tf.Graph()

        with self.graph.as_default():
            # Single random sample used for the warm-up run below.
            features_int_np, features_cat_np = rand_features_np(1, num_d, num_s, minsize)

            # tf.compat.v1.placeholder matches the other compat.v1 calls in
            # this method; the bare tf.placeholder symbol is absent in TF 2.x.
            features_int = tf.compat.v1.placeholder(tf.float32, [None, num_d], name="ph_1")
            features_cat = tf.compat.v1.placeholder(tf.int32, [None, num_s], name="ph_2")

            preds = logits_fn(features_int, features_cat, self.params)
            preds = tf.identity(preds, name="preds")

            init_op = tf.compat.v1.global_variables_initializer()

        self.sess = tf.compat.v1.Session(graph=self.graph)

        self.sess.run(init_op)
        self.sess.run(preds, feed_dict={features_int: features_int_np, features_cat: features_cat_np})

        # Keep direct handles so predict() need not look tensors up by name.
        self._ph_dense = features_int
        self._ph_sparse = features_cat
        self._preds = preds

        self.params["is_training"] = False

        print("load() finished ...")

        return self

    def predict(self, batch_dense_X, batch_lS_o, batch_lS_i):
        """Run one inference batch: torch -> numpy -> tf -> numpy -> torch.

        Args:
            batch_dense_X: torch tensor of dense features.
            batch_lS_o: unused by this backend.
            batch_lS_i: torch tensor of sparse indices; it is transposed
                before being fed (presumably arrives as
                (num_sparse, batch) -- verify against the caller).

        Returns:
            torch tensor wrapping the numpy output of the ``preds`` tensor.

        Raises:
            RuntimeError: if load() has not been called yet.
        """
        if self.sess is None or self._preds is None:
            raise RuntimeError("BackendTF.predict() called before load()")

        # dense features
        dense_np = batch_dense_X.detach().cpu().numpy()

        # sparse features
        sparse_np = np.transpose(batch_lS_i.detach().cpu().numpy())

        np_tensor_out = self.sess.run(
            self._preds,
            feed_dict={self._ph_dense: dense_np, self._ph_sparse: sparse_np},
        )

        print("1st output element: ", np_tensor_out[:1])
        return torch.from_numpy(np_tensor_out)

recommendation/dlrm/pytorch/python/main.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,22 @@
9393
"model": "dlrm",
9494
"max-batchsize": 2048,
9595
},
96-
97-
96+
"tf_dlrm-kaggle-tensorflow": {
97+
"dataset": "kaggle",
98+
"inputs": "continuous and categorical features",
99+
"outputs": "probability",
100+
"backend": "tensorflow",
101+
"model": "tf_dlrm",
102+
"max-batchsize": 128,
103+
},
104+
"tf_dlrm-terabyte-tensorflow": {
105+
"dataset": "terabyte",
106+
"inputs": "continuous and categorical features",
107+
"outputs": "probability",
108+
"backend": "tensorflow",
109+
"model": "tf_dlrm",
110+
"max-batchsize": 2048,
111+
},
98112
}
99113

100114
SCENARIO_MAP = {
@@ -253,6 +267,39 @@ def get_backend(backend, dataset, max_ind_range, data_sub_sample_rate, use_gpu):
253267
else:
254268
raise ValueError("only kaggle|terabyte dataset options are supported")
255269

270+
elif backend == "tensorflow":
271+
from backend_tf import BackendTF
272+
# NOTE: pass model parameters here, the following options are available
273+
if dataset == "kaggle":
274+
# 1. Criteo Kaggle Display Advertisement Challenge Dataset (see ./bench/dlrm_s_criteo_kaggle.sh)
275+
backend = BackendTF(
276+
dim_embed=16,
277+
vocab_sizes=np.array([1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572]),
278+
mlp_bottom=np.array([13,512,256,64,16]),
279+
mlp_top=np.array([367,512,256,1]),
280+
)
281+
elif dataset == "terabyte":
282+
if max_ind_range == 10000000:
283+
# 2. Criteo Terabyte (see ./bench/dlrm_s_criteo_terabyte.sh [--sub-sample=0.875] --max-in-range=10000000)
284+
backend = BackendTF(
285+
dim_embed=64,
286+
vocab_sizes=np.array([9980333,36084,17217,7378,20134,3,7112,1442,61, 9758201,1333352,313829,10,2208,11156,122,4,970,14, 9994222, 7267859, 9946608,415421,12420,101, 36]),
287+
mlp_bottom=np.array([13,512,256,64]),
288+
mlp_top=np.array([415,512,512,256,1]),
289+
)
290+
elif max_ind_range == 40000000:
291+
# 3. Criteo Terabyte MLPerf training (see ./bench/run_and_time.sh --max-in-range=40000000)
292+
backend = BackendTF(
293+
dim_embed=128,
294+
vocab_sizes=np.array([39884406,39043,17289,7420,20263,3,7120,1543,63,38532951,2953546,403346,10,2208,11938,155,4,976,14,39979771,25641295,39664984,585935,12972,108,36]),
295+
mlp_bottom=np.array([13,512,256,128]),
296+
mlp_top=np.array([479,1024,1024,512,256,1]),
297+
)
298+
else:
299+
raise ValueError("only --max-in-range 10M or 40M is supported")
300+
else:
301+
raise ValueError("only kaggle|terabyte dataset options are supported")
302+
256303
else:
257304
raise ValueError("unknown backend: " + backend)
258305
return backend

0 commit comments

Comments
 (0)