在用自己的数据集运行DG库中的MMD算法时,出现了矩阵形状不一样导致无法相加的问题。下图是github中MMD计算的代码:

我的数据,x1和x2的形状均为(32, 600, 64)。考虑到batch问题,我在计算时将addmm换为了addbmm,但是我发现,无论是用二维数据使用addmm还是三维数据使用addbmm,在计算时存在矩阵形状不一样导致无法相加的问题,各个张量的形状如下图所示:

可以看到,x2_norm.transpose(-2, -1)的形状与matmul(x1, x2.transpose(-2, -1))的形状不一致,二者是没办法做和相加的。麻烦大佬们看一下,是哪里有了问题?感谢
在用自己的数据集运行DG库中的MMD算法时,出现了矩阵形状不一样导致无法相加的问题。下图是github中MMD计算的代码:


我的数据,x1和x2的形状均为(32, 600, 64)。考虑到batch问题,我在计算时将
addmm换为了addbmm,但是我发现,无论是用二维数据使用addmm还是三维数据使用addbmm,在计算时存在矩阵形状不一样导致无法相加的问题,各个张量的形状如下图所示:可以看到,
x2_norm.transpose(-2, -1)的形状与matmul(x1, x2.transpose(-2, -1))的形状不一致,二者是没办法做和相加的。麻烦大佬们看一下,是哪里有了问题?感谢