Skip to content

Commit 3728c0d

Browse files
committed
remaining post-rebase fix
1 parent 0d450ae commit 3728c0d

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

hls4ml/converters/keras_v3/merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def handle(
4747
f'{in_tensors[0].shape} and {in_tensors[1].shape} for layer {layer.name}.'
4848
)
4949
assert all(len(t.shape) == 2 for t in in_tensors), msg
50-
assert in_tensors[0].shape[1] == in_tensors[1].shape[0], f'Input shape mismatch for layer {layer.name}.'
50+
assert in_tensors[0].shape[1] == in_tensors[1].shape[1], f'Input shape mismatch for layer {layer.name}.'
5151
class_name = 'Dot'
5252
op = 'dot1d'
5353
config['axes'] = layer.axes

hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import re
22
import typing
3+
from collections.abc import Sequence
34
from copy import copy
45
from functools import reduce, singledispatch
56
from math import ceil, log2, prod
6-
from typing import Sequence
77
from warnings import warn
88

99
import numpy as np
@@ -313,18 +313,14 @@ def _(layer: Merge):
313313

314314
@_produce_kif.register
315315
def _(layer: EinsumDense):
316-
t_kernel = layer.attributes.attributes['weight'].data
317-
to_original_kernel = layer.attributes.attributes['to_original_kernel']
318-
kernel = to_original_kernel(t_kernel)
316+
kernel = layer.attributes.attributes['weight'].data
319317
_bias = layer.attributes.attributes['bias']
320318
eq = layer.attributes.attributes['equation']
321319
k_in, i_in, f_in = get_input_kifs(layer)[0]
322320
qint_in = QIntervalArray.from_kif(k_in, i_in, f_in)
323321
qint_out = einsum(eq, qint_in, kernel)
324322
if _bias is not None:
325-
t_bias = _bias.data
326-
bias = t_bias.transpose(layer.attributes.attributes['out_tpose_idxs'])
327-
qint_out = qint_out + bias
323+
qint_out = qint_out + _bias.data
328324
k, i, f = qint_out.to_kif()
329325
return k.astype(np.int8), i, f
330326

0 commit comments

Comments
 (0)