Releases: ML4ITS/mtad-gat-pytorch
v1.0.2
Fix
- Changed bias initialization to avoid NaN errors during training (d00c408)
Docs
- Added contributions list to README.md (5744087)
More info from @srigas regarding the fix:
Hello again, this is another fix for a common problem reported in the Issues of the mtad-gat-pytorch repository. The use_bias parameter of the GAT layers is set to True by default. However, the biases are allocated with torch.empty and never explicitly initialized, so the corresponding tensors can contain arbitrary leftover values from memory, which leads to NaN values during training. Setting use_bias to False prevents the NaNs, but then the user cannot include biases in the model's architecture. The proposed solution, initializing the biases to zero, solves both problems and is also the de facto approach in other large projects (for example, the GATConv class in PyG calls zeros(self.bias) in its reset_parameters function, indicating that zero initialization of the biases is acceptable).
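A minimal sketch of the pattern described above, assuming PyTorch. BiasedLinearSketch is a hypothetical layer, not the repository's GAT layer: it allocates its bias with torch.empty, as the original code did, and then overwrites it with zeros in reset_parameters, mirroring what PyG's GATConv does. Without the nn.init.zeros_ call, the bias would keep whatever values happened to be in the allocated memory.

```python
import torch
import torch.nn as nn


class BiasedLinearSketch(nn.Module):
    """Hypothetical layer: a linear map plus an explicitly allocated bias.

    Illustrates the bias-initialization fix; it is not the repository's code.
    """

    def __init__(self, in_dim: int, out_dim: int, use_bias: bool = True):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=False)
        self.use_bias = use_bias
        if use_bias:
            # torch.empty only allocates memory; the contents are arbitrary
            # until something writes to them.
            self.bias = nn.Parameter(torch.empty(out_dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.xavier_uniform_(self.lin.weight)
        if self.use_bias:
            # The fix: give the bias a well-defined starting value (zero).
            nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.lin(x)
        if self.use_bias:
            out = out + self.bias
        return out


layer = BiasedLinearSketch(4, 3)
out = layer(torch.randn(2, 4))
```

Keeping use_bias=True as the default is unaffected: the bias is still a trainable parameter, it just starts from a known value instead of uninitialized memory.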
v1.0.1
Fix
Thanks to @srigas for these contributions!
More info from @srigas regarding these fixes:
- During prediction, the console is flooded with PerformanceWarning messages. This happens because prediction.py initializes an empty DataFrame and then adds columns and values to it one at a time inside a for loop. Repeated column insertion leaves the DataFrame highly fragmented internally, which pandas flags as bad practice with these warnings. Collecting the columns in a dictionary and converting it into a DataFrame in a single step solves the problem.
- More importantly, predict.py reads the wrong arguments during prediction: all model parameters are taken from args instead of model_args. As a result, either the user must pass all model parameters on the command line during both training and prediction, or, if they are not passed during prediction, prediction runs without an error only if the model was trained with the default parameters. Replacing args with model_args in the affected part of the code solves this problem.
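The DataFrame fix from the first item can be sketched as follows, assuming pandas and NumPy. The column names, sizes, and helper functions are hypothetical and only stand in for the per-feature outputs built in prediction.py; the sketch contrasts inserting columns one by one with building the DataFrame from a dictionary in one call.

```python
import warnings

import numpy as np
import pandas as pd

# Hypothetical per-feature outputs; names and sizes are illustrative only.
N_ROWS, N_FEATURES = 100, 200
rng = np.random.default_rng(0)
columns = {f"Forecast_{i}": rng.random(N_ROWS) for i in range(N_FEATURES)}


def build_incrementally(cols: dict) -> pd.DataFrame:
    """Anti-pattern: start empty and insert one column per loop iteration."""
    df = pd.DataFrame()
    for name, values in cols.items():
        df[name] = values  # each new column is stored as a separate internal block
    return df


def build_from_dict(cols: dict) -> pd.DataFrame:
    """Fix: gather all columns first, then build the DataFrame in one call."""
    return pd.DataFrame(cols)


# Record (rather than print) any warnings raised by the incremental version.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    slow = build_incrementally(columns)
perf_warnings = [
    w for w in caught if issubclass(w.category, pd.errors.PerformanceWarning)
]

fast = build_from_dict(columns)
```

Both functions produce the same table. On recent pandas versions, perf_warnings is typically non-empty once more than about 100 columns have been inserted one at a time, while the dictionary-based construction avoids the warning altogether.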
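The args versus model_args problem from the second item can be illustrated with the standard library alone. This is a sketch under stated assumptions: the option names (--lookback, --gru_hid_dim), the config.txt file name, and the helper functions are hypothetical and may not match predict.py exactly. The point is that architecture-defining values should come from the configuration saved at training time (model_args), not from the prediction-time command line (args), whose defaults silently take over when the user omits them.

```python
import argparse
import json
import tempfile
from pathlib import Path


def parse_prediction_args(argv):
    """Prediction-time CLI (hypothetical subset of options)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", default="latest")
    # Model hyperparameters also have CLI defaults, which hides the bug.
    parser.add_argument("--lookback", type=int, default=100)
    parser.add_argument("--gru_hid_dim", type=int, default=150)
    return parser.parse_args(argv)


def load_model_args(config_path):
    """Load the hyperparameters saved at training time into a Namespace."""
    with open(config_path) as f:
        return argparse.Namespace(**json.load(f))


def model_kwargs(model_args):
    # Architecture-defining values come from the saved training config.
    return {"window_size": model_args.lookback, "gru_hid_dim": model_args.gru_hid_dim}


with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / "config.txt"
    # Pretend the model was trained with non-default hyperparameters.
    config_path.write_text(json.dumps({"lookback": 50, "gru_hid_dim": 64}))

    args = parse_prediction_args([])  # no hyperparameters passed at prediction time
    model_args = load_model_args(config_path)

    # Buggy: reading from args silently falls back to the CLI defaults.
    buggy = {"window_size": args.lookback, "gru_hid_dim": args.gru_hid_dim}
    # Fixed: reading from model_args reproduces the trained architecture.
    fixed = model_kwargs(model_args)
```

With the buggy version, a model rebuilt from buggy would have a different architecture from the trained one, so loading the saved weights would fail unless the training run happened to use the defaults.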