Skip to content

Commit 1559b31

Browse files
committed
Add an adversarial notebook for ONNX
1 parent 066bf54 commit 1559b31

File tree

3 files changed

+329
-7
lines changed

3 files changed

+329
-7
lines changed
Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "0",
6+
"metadata": {},
7+
"source": [
8+
"# Adversarial example using ONNX\n",
9+
"\n"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"id": "1",
15+
"metadata": {},
16+
"source": [
17+
"## Import the necessary packages and load data\n",
18+
"\n"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"id": "2",
25+
"metadata": {},
26+
"outputs": [],
27+
"source": [
28+
"from matplotlib import pyplot as plt\n",
29+
"import numpy as np\n",
30+
"import tensorflow as tf\n",
31+
"from tensorflow import keras\n",
32+
"import onnx\n",
33+
"from onnx import helper, TensorProto\n",
34+
"\n",
35+
"import gurobipy as gp\n",
36+
"\n",
37+
"from gurobi_ml import add_predictor_constr"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": null,
43+
"id": "3",
44+
"metadata": {},
45+
"outputs": [],
46+
"source": [
47+
"(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()"
48+
]
49+
},
50+
{
51+
"cell_type": "markdown",
52+
"id": "4",
53+
"metadata": {},
54+
"source": [
55+
"We reshape and scale `x_train` and `x_test`.\n"
56+
]
57+
},
58+
{
59+
"cell_type": "code",
60+
"execution_count": null,
61+
"id": "5",
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"x_train = tf.reshape(tf.cast(x_train, tf.float32) / 255.0, [-1, 28 * 28])\n",
66+
"x_test = tf.reshape(tf.cast(x_test, tf.float32) / 255.0, [-1, 28 * 28])"
67+
]
68+
},
69+
{
70+
"cell_type": "markdown",
71+
"id": "6",
72+
"metadata": {},
73+
"source": [
74+
"## Construct and train the neural network\n",
75+
"\n"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"id": "7",
82+
"metadata": {},
83+
"outputs": [],
84+
"source": [
85+
"nn = tf.keras.models.Sequential(\n",
86+
" [\n",
87+
" tf.keras.layers.InputLayer((28 * 28,)),\n",
88+
" tf.keras.layers.Dense(50, activation=\"relu\"),\n",
89+
" tf.keras.layers.Dense(50, activation=\"relu\"),\n",
90+
" tf.keras.layers.Dense(10), # logits\n",
91+
" ]\n",
92+
")"
93+
]
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": null,
98+
"id": "8",
99+
"metadata": {},
100+
"outputs": [],
101+
"source": [
102+
"nn.compile(\n",
103+
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
104+
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
105+
" metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n",
106+
")"
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": null,
112+
"id": "9",
113+
"metadata": {},
114+
"outputs": [],
115+
"source": [
116+
"nn.fit(\n",
117+
" x_train,\n",
118+
" y_train,\n",
119+
" epochs=3,\n",
120+
" validation_data=(x_test, y_test),\n",
121+
")"
122+
]
123+
},
124+
{
125+
"cell_type": "markdown",
126+
"id": "10",
127+
"metadata": {},
128+
"source": [
129+
"Convert the trained Keras model to an ONNX MLP\n"
130+
]
131+
},
132+
{
133+
"cell_type": "code",
134+
"execution_count": null,
135+
"id": "11",
136+
"metadata": {},
137+
"outputs": [],
138+
"source": [
139+
"def keras_dense_layers_to_onnx(model):\n",
140+
" # Extract dense layers weights/bias and activation\n",
141+
" layers = []\n",
142+
" in_dim = None\n",
143+
" for layer in model.layers:\n",
144+
" if isinstance(layer, tf.keras.layers.InputLayer):\n",
145+
" try:\n",
146+
" in_dim = layer.input_shape[-1]\n",
147+
" except Exception:\n",
148+
" pass\n",
149+
" elif isinstance(layer, tf.keras.layers.Dense):\n",
150+
" W, b = layer.get_weights()\n",
151+
" act = layer.get_config().get(\"activation\", \"linear\")\n",
152+
" layers.append((W.astype(np.float32), b.astype(np.float32), act))\n",
153+
"\n",
154+
" # Build ONNX graph from collected layers\n",
155+
" n_in = in_dim or layers[0][0].shape[0]\n",
156+
" X = helper.make_tensor_value_info(\"X\", TensorProto.FLOAT, [None, n_in])\n",
157+
"\n",
158+
" last = \"X\"\n",
159+
" inits = []\n",
160+
" nodes = []\n",
161+
" for i, (W, b, act) in enumerate(layers):\n",
162+
" W_name = f\"W{i + 1}\"\n",
163+
" b_name = f\"b{i + 1}\"\n",
164+
" # Gemm with transB=1 realizes (last @ W + b) when B is W.T\n",
165+
" inits.append(\n",
166+
" helper.make_tensor(W_name, TensorProto.FLOAT, W.T.shape, W.T.flatten())\n",
167+
" )\n",
168+
" inits.append(helper.make_tensor(b_name, TensorProto.FLOAT, b.shape, b))\n",
169+
" out_name = f\"H{i + 1}\"\n",
170+
" nodes.append(\n",
171+
" helper.make_node(\n",
172+
" \"Gemm\",\n",
173+
" inputs=[last, W_name, b_name],\n",
174+
" outputs=[out_name],\n",
175+
" name=f\"gemm{i + 1}\",\n",
176+
" transB=1,\n",
177+
" )\n",
178+
" )\n",
179+
" last = out_name\n",
180+
" if act == \"relu\":\n",
181+
" act_name = f\"A{i + 1}\"\n",
182+
" nodes.append(\n",
183+
" helper.make_node(\n",
184+
" \"Relu\", inputs=[last], outputs=[act_name], name=f\"relu{i + 1}\"\n",
185+
" )\n",
186+
" )\n",
187+
" last = act_name\n",
188+
"\n",
189+
" # Connect final tensor to a named output via Identity\n",
190+
" n_out = layers[-1][1].shape[0]\n",
191+
" nodes.append(\n",
192+
" helper.make_node(\"Identity\", inputs=[last], outputs=[\"Y\"], name=\"output\")\n",
193+
" )\n",
194+
" Y = helper.make_tensor_value_info(\"Y\", TensorProto.FLOAT, [None, n_out])\n",
195+
" graph = helper.make_graph(\n",
196+
" nodes=nodes, name=\"KerasMLP\", inputs=[X], outputs=[Y], initializer=inits\n",
197+
" )\n",
198+
" model = helper.make_model(graph)\n",
199+
" onnx.checker.check_model(model)\n",
200+
" return model\n",
201+
"\n",
202+
"\n",
203+
"onnx_model = keras_dense_layers_to_onnx(nn)"
204+
]
205+
},
206+
{
207+
"cell_type": "markdown",
208+
"id": "12",
209+
"metadata": {},
210+
"source": [
211+
"## Build optimization model\n",
212+
"\n",
213+
"Now we turn to building the optimization model.\n",
214+
"\n",
215+
"We choose a training example and follow the same steps as the Keras example.\n"
216+
]
217+
},
218+
{
219+
"cell_type": "code",
220+
"execution_count": null,
221+
"id": "13",
222+
"metadata": {},
223+
"outputs": [],
224+
"source": [
225+
"example = x_train[18, :]\n",
226+
"plt.imshow(tf.reshape(example, [28, 28]), cmap=\"gray\")\n",
227+
"ex_prob = nn.predict(tf.reshape(example, (1, -1)))\n",
228+
"sorted_labels = tf.argsort(ex_prob)[0]\n",
229+
"right_label = sorted_labels[-1]\n",
230+
"wrong_label = sorted_labels[-2]\n",
231+
"print(\n",
232+
" f\"Original classified as {int(right_label)}; target misclassify as {int(wrong_label)}\"\n",
233+
")"
234+
]
235+
},
236+
{
237+
"cell_type": "code",
238+
"execution_count": null,
239+
"id": "14",
240+
"metadata": {},
241+
"outputs": [],
242+
"source": [
243+
"m = gp.Model()\n",
244+
"delta = 5\n",
245+
"\n",
246+
"x = m.addMVar(example.numpy().shape, lb=0.0, ub=1.0, name=\"x\")\n",
247+
"y = m.addMVar(ex_prob.shape, lb=-gp.GRB.INFINITY, name=\"y\")\n",
248+
"\n",
249+
"abs_diff = m.addMVar(example.numpy().shape, lb=0, ub=1, name=\"abs_diff\")\n",
250+
"\n",
251+
"m.setObjective(y[0, wrong_label] - y[0, right_label], gp.GRB.MAXIMIZE)\n",
252+
"\n",
253+
"# Bound on the distance to example in norm-1\n",
254+
"m.addConstr(abs_diff >= x - example.numpy())\n",
255+
"m.addConstr(abs_diff >= -x + example.numpy())\n",
256+
"m.addConstr(abs_diff.sum() <= delta)\n",
257+
"\n",
258+
"pred_constr = add_predictor_constr(m, onnx_model, x, y)\n",
259+
"\n",
260+
"pred_constr.print_stats()"
261+
]
262+
},
263+
{
264+
"cell_type": "code",
265+
"execution_count": null,
266+
"id": "15",
267+
"metadata": {},
268+
"outputs": [],
269+
"source": [
270+
"m.Params.BestBdStop = 0.0\n",
271+
"m.Params.BestObjStop = 0.0\n",
272+
"m.optimize()"
273+
]
274+
},
275+
{
276+
"cell_type": "markdown",
277+
"id": "16",
278+
"metadata": {},
279+
"source": [
280+
"Finally, display the adversarial example if one was found.\n"
281+
]
282+
},
283+
{
284+
"cell_type": "code",
285+
"execution_count": null,
286+
"id": "17",
287+
"metadata": {},
288+
"outputs": [],
289+
"source": [
290+
"if m.SolCount and m.ObjVal > 0.0:\n",
291+
" plt.imshow(x.X.reshape((28, 28)), cmap=\"gray\")\n",
292+
" label = tf.math.argmax(nn.predict(tf.reshape(x.X, (1, -1))), axis=1)\n",
293+
" print(f\"Solution is classified as {label.numpy()[0]}\")\n",
294+
"else:\n",
295+
" print(\"No counter example exists in neighborhood.\")"
296+
]
297+
}
298+
],
299+
"metadata": {
300+
"kernelspec": {
301+
"display_name": "Python 3 (ipykernel)",
302+
"language": "python",
303+
"name": "python3"
304+
},
305+
"language_info": {
306+
"codemirror_mode": {
307+
"name": "ipython",
308+
"version": 3
309+
},
310+
"file_extension": ".py",
311+
"mimetype": "text/x-python",
312+
"name": "python",
313+
"nbconvert_exporter": "python",
314+
"pygments_lexer": "ipython3",
315+
"version": "3.13.3"
316+
},
317+
"license": {
318+
"full_text": "# Copyright © 2025 Gurobi Optimization, LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =============================================================================="
319+
}
320+
},
321+
"nbformat": 4,
322+
"nbformat_minor": 5
323+
}

src/gurobi_ml/onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright © 2023 Gurobi Optimization, LLC
1+
# Copyright © 2025 Gurobi Optimization, LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

src/gurobi_ml/onnx/onnx_model.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright © 2023 Gurobi Optimization, LLC
1+
# Copyright © 2025 Gurobi Optimization, LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import List, Optional
2625

2726
import numpy as np
2827

@@ -82,13 +81,13 @@ def __init__(self, gp_model, predictor, input_vars, output_vars=None, **kwargs):
8281
if not isinstance(predictor, onnx.ModelProto):
8382
raise NoModel(predictor, "Expected an onnx.ModelProto model")
8483

85-
self._layers_spec: List[_ONNXLayer] = self._parse_mlp(predictor)
84+
self._layers_spec: list[_ONNXLayer] = self._parse_mlp(predictor)
8685
if not self._layers_spec:
8786
raise NoModel(predictor, "Empty or unsupported ONNX graph")
8887

8988
super().__init__(gp_model, predictor, input_vars, output_vars, **kwargs)
9089

91-
def _parse_mlp(self, model: "onnx.ModelProto") -> List[_ONNXLayer]:
90+
def _parse_mlp(self, model: onnx.ModelProto) -> list[_ONNXLayer]:
9291
"""Parse a limited subset of ONNX graphs representing MLPs.
9392
9493
We support sequences of: Gemm -> (Relu)? -> Gemm -> (Relu)? ...
@@ -114,8 +113,8 @@ def _get_attr(node, name, default=None):
114113
return default
115114

116115
# Iterate nodes gathering dense layers and relus
117-
layers: List[_ONNXLayer] = []
118-
pending_activation: Optional[str] = None
116+
layers: list[_ONNXLayer] = []
117+
pending_activation: str | None = None
119118

120119
for node in graph.node:
121120
op = node.op_type

0 commit comments

Comments
 (0)