Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

运行MMD时矩阵不一致问题的报错。 #400

Open
sunsenheping opened this issue Jul 18, 2023 · 5 comments
Open

运行MMD时矩阵不一致问题的报错。 #400

sunsenheping opened this issue Jul 18, 2023 · 5 comments

Comments

@sunsenheping
Copy link

在用自己的数据集运行DG库中的MMD算法时,出现了矩阵形状不一样导致无法相加的问题。下图是github中MMD计算的代码:
b71752db2290bfd5703138368d41c05
我的数据,x1和x2的形状均为(32, 600, 64)。考虑到batch问题,我在计算时将addmm换为了addbmm,但是我发现,无论是用二维数据使用addmm还是三维数据使用addbmm,在计算时存在矩阵形状不一样导致无法相加的问题,各个张量的形状如下图所示:
66e83a0c316846b11decc46f064f3c4
可以看到,x2_norm.transpose(-2, -1)的形状与matmul(x1, x2.transpose(-2, -1))的形状不一致,二者是没办法做和相加的。麻烦大佬们看一下,是哪里有了问题?感谢

@jindongwang
Copy link
Owner

@lw0517 有时间来看一下。

@lw0517
Copy link
Collaborator

lw0517 commented Jul 18, 2023

在用自己的数据集运行DG库中的MMD算法时,出现了矩阵形状不一样导致无法相加的问题。下图是github中MMD计算的代码: b71752db2290bfd5703138368d41c05 我的数据,x1和x2的形状均为(32, 600, 64)。考虑到batch问题,我在计算时将addmm换为了addbmm,但是我发现,无论是用二维数据使用addmm还是三维数据使用addbmm,在计算时存在矩阵形状不一样导致无法相加的问题,各个张量的形状如下图所示: 66e83a0c316846b11decc46f064f3c4 可以看到,x2_norm.transpose(-2, -1)的形状与matmul(x1, x2.transpose(-2, -1))的形状不一致,二者是没办法做和相加的。麻烦大佬们看一下,是哪里有了问题?感谢

建议的做法是将三维的转化成二维的进行MMD距离计算。如果用torch.addbmm,最好参考一下https://pytorch.org/docs/1.10/generated/torch.addmm.html?highlight=addmm#torch.addmm, 可以看出来维度是不一致的。

@sunsenheping
Copy link
Author

在用自己的数据集运行DG库中的MMD算法时,出现了矩阵形状不一样导致无法相加的问题。下图是github中MMD计算的代码: b71752db2290bfd5703138368d41c05 我的数据,x1和x2的形状均为(32, 600, 64)。考虑到batch问题,我在计算时将addmm换为了addbmm,但是我发现,无论是用二维数据使用addmm还是三维数据使用addbmm,在计算时存在矩阵形状不一样导致无法相加的问题,各个张量的形状如下图所示: 66e83a0c316846b11decc46f064f3c4 可以看到,x2_norm.transpose(-2, -1)的形状与matmul(x1, x2.transpose(-2, -1))的形状不一致,二者是没办法做和相加的。麻烦大佬们看一下,是哪里有了问题?感谢

建议的做法是将三维的转化成二维的进行MMD距离计算。如果用torch.addbmm,最好参考一下https://pytorch.org/docs/1.10/generated/torch.addmm.html?highlight=addmm#torch.addmm, 可以看出来维度是不一致的。

问题是,我用二维矩阵也试过,按照你们MMD的代码,还是存在形状不一致,addmm无法进行相加的问题的。你也可以推导一下。
image

@sunsenheping
Copy link
Author

如果x2_norm.transpose(-2, -1)的形状与matmul(x1, x2.transpose(-2, -1))的形状不一致,二者是没办法做和相加的,那么你们现在的MMD代码也就无法跑通呀。

@jindongwang
Copy link
Owner

我们是按照batch来算的,一个batch内大家形状是一样的,不明白为什么会出现形状不一样的问题。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants