@@ -71,6 +71,20 @@ struct Shape {
   }
 };
 
+inline std::ostream& operator<<(std::ostream& os, const Shape& shape)
+{
+  int size = shape.rank;
+  os << "Shape: [";
+  for (int i = 0; i < size - 1; i++) {
+    os << shape.data[i] << ", ";
+  }
+  if (size != 0) {
+    os << shape.data[size - 1];
+  }
+  os << "]";
+  return os;
+}
+
 /**
  * @brief Returns the number of elements in a tensor with the given shape,
  * which is equal to the product of the dimensions.
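Note on the Shape printer added above: a minimal usage sketch, assuming Shape keeps its initializer-list constructor and that <iostream> is included; the variable name is illustrative, not part of the diff:

    Shape shape{2, 3, 4};
    std::cout << shape << std::endl;  // expected to print: Shape: [2, 3, 4]
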
@@ -210,30 +224,30 @@ enum NumType {
 /**
  * @brief Returns the number of bytes of a number type.
  */
-inline size_t sizeBytes(const NumType &type) {
+inline size_t sizeBytes(const NumType &type, int numElements = 1) {
   switch (type) {
   case kf16:
-    return sizeof(half);
+    return sizeof(half) * numElements;
   case kf32:
-    return sizeof(float);
+    return sizeof(float) * numElements;
   case kf64:
-    return sizeof(double);
+    return sizeof(double) * numElements;
   case ki8:
-    return sizeof(int8_t);
+    return sizeof(uint32_t) * ((numElements + 3) / 4);
   case ki16:
-    return sizeof(int16_t);
+    return sizeof(uint32_t) * ((numElements + 1) / 2);
   case ki32:
-    return sizeof(int32_t);
+    return sizeof(int32_t) * numElements;
   case ki64:
-    return sizeof(int64_t);
+    return sizeof(int64_t) * numElements;
   case ku8:
-    return sizeof(uint8_t);
+    return sizeof(uint32_t) * ((numElements + 3) / 4);
   case ku16:
-    return sizeof(uint16_t);
+    return sizeof(uint32_t) * ((numElements + 1) / 2);
   case ku32:
-    return sizeof(uint32_t);
+    return sizeof(uint32_t) * numElements;
   case ku64:
-    return sizeof(uint64_t);
+    return sizeof(uint64_t) * numElements;
   default:
     LOG(kDefLog, kError, "Invalid NumType in size calculation.");
     return 0;
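With the numElements parameter introduced above, sub-32-bit integer types report the size of their packed uint32_t storage rather than numElements * sizeof(T). A quick sanity check of the arithmetic in this hunk (a sketch, not part of the diff):

    // ki8 / ku8: four 8-bit values per 32-bit word, rounded up
    //   sizeBytes(ki8, 5)  == sizeof(uint32_t) * ((5 + 3) / 4) == 8
    // ki16 / ku16: two 16-bit values per 32-bit word, rounded up
    //   sizeBytes(ki16, 3) == sizeof(uint32_t) * ((3 + 1) / 2) == 8
    // 32-bit and wider types are a plain multiply
    //   sizeBytes(kf32, 3) == sizeof(float) * 3 == 12
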
@@ -697,7 +711,7 @@ inline Tensor createTensor(TensorPool &pool, WGPUDevice &device,
                                WGPUBufferUsage_CopySrc) {
   LOG(kDefLog, kTrace, "Creating tensor");
   size_t numElements = size(shape);
-  size_t size = sizeBytes(dtype) * numElements;
+  size_t size = sizeBytes(dtype, numElements);
   WGPUBufferDescriptor bufferDesc = {
       .label = {.data = nullptr, .length = 0},
       .usage = usage,
@@ -828,7 +842,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
     // unpacking
     packed[idx] |= (static_cast<uint8_t>(data[i]) << shift);
   }
-  return createTensor(ctx, shape, ki32, packed.data());
+  Tensor tensor = createTensor(ctx, shape, ki8);
+  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, packed.data(),
+                       tensor.data.size);
+  return tensor;
 }
 
 // Overload for int16_t: pack two 16-bit ints into one 32-bit integer
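For context on the ki8 path just above: the hunk shows only the tail of the host-side packing loop. A rough sketch of what the full packing presumably looks like, assuming the loop variables match the ones visible in the diff (packed, idx, shift, data, numElements); the actual element type of packed may differ in the real code:

    std::vector<uint32_t> packed((numElements + 3) / 4, 0);
    for (size_t i = 0; i < numElements; ++i) {
      size_t idx = i / 4;          // four int8 values per 32-bit word
      size_t shift = (i % 4) * 8;  // byte lane within that word
      packed[idx] |= (static_cast<uint8_t>(data[i]) << shift);
    }
    // The packed words are then uploaded with wgpuQueueWriteBuffer, and
    // tensor.data.size already equals the packed byte count because
    // sizeBytes(ki8, numElements) rounds up to whole 32-bit words.
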
@@ -843,7 +860,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
     size_t shift = (i % 2) * 16;
     packed[idx] |= (static_cast<uint16_t>(data[i]) << shift);
   }
-  return createTensor(ctx, shape, ki32, packed.data());
+  Tensor tensor = createTensor(ctx, shape, ki16);
+  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, packed.data(),
+                       tensor.data.size);
+  return tensor;
 }
 
 // Overload for int64_t: pack each 64-bit int into two 32-bit integers
@@ -857,7 +877,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
     packed[2 * i] = static_cast<int32_t>(val & 0xFFFFFFFF);
     packed[2 * i + 1] = static_cast<int32_t>((val >> 32) & 0xFFFFFFFF);
   }
-  return createTensor(ctx, shape, ki32, packed.data());
+  Tensor tensor = createTensor(ctx, shape, ki64);
+  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, packed.data(),
+                       tensor.data.size);
+  return tensor;
 }
 
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
@@ -885,7 +908,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
     size_t shift = (i % 4) * 8;
     packed[idx] |= (static_cast<uint32_t>(data[i]) << shift);
   }
-  return createTensor(ctx, shape, ku32, packed.data());
+  Tensor tensor = createTensor(ctx, shape, ku8);
+  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, packed.data(),
+                       tensor.data.size);
+  return tensor;
 }
 
 // Overload for uint16_t: pack two 16-bit integers into one 32-bit unsigned
@@ -901,7 +927,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
     size_t shift = (i % 2) * 16;
     packed[idx] |= (static_cast<uint32_t>(data[i]) << shift);
   }
-  return createTensor(ctx, shape, ku32, packed.data());
+  Tensor tensor = createTensor(ctx, shape, ku16);
+  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, packed.data(),
+                       tensor.data.size);
+  return tensor;
 }
 
 // Overload for uint64_t: pack each 64-bit integer into two 32-bit unsigned
@@ -916,7 +945,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
     packed[2 * i] = static_cast<uint32_t>(val & 0xFFFFFFFF);
     packed[2 * i + 1] = static_cast<uint32_t>(val >> 32);
   }
-  return createTensor(ctx, shape, ku32, packed.data());
+  Tensor tensor = createTensor(ctx, shape, ku64);
+  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, packed.data(),
+                       tensor.data.size);
+  return tensor;
 }
 
 /**
@@ -1987,7 +2019,7 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, NumType dtype, void *output,
   case kf32:
   case ku32:
   case ki32: {
-    size_t byteSize = numElements * sizeBytes(dtype);
+    size_t byteSize = sizeBytes(dtype, numElements);
     toCPU(ctx, buffer, output, byteSize, sourceOffset);
     break;
   }