7 implement ot #47
Conversation
Nice work! I've left a few comments and am happy to discuss further.

Re. the points you made in the opening comment:

> The PyTest tests can be run as they are but further otdd configurations could be added into the `metrics.yaml` file - currently there are only two options there and there are a lot of parameters to experiment with

I don't think we should worry too much about testing all param combinations; we should be able to assume that otdd itself has been tested. We just need to sanity check that we're calling it correctly, and I think we've done that.

> When I was testing I experienced the otdd tests failing inconsistently for the `otdd_naive_upperbound` test parameters. After setting a random seed when initialising the `OTDD` object the tests no longer failed... however, I am not 100% sure if this is a definite fix.

OK, seems reasonable. I've made a specific comment on the seeding.

> The `calculate_metrics.py` script also needs testing

Agreed. Up to you whether you want to create a dedicated `test_calculate_metrics.py` or test locally.

> Further safeguarding to prevent metrics being calculated on Apple M1 cpu

Agree, have made some comments.

> The parameters to be set for OTDD - will the exact calculation be too computationally intensive?

Very good point. Let's try it out on Baskerville and make some adjustments if necessary; this can be done outside of this PR.

> Do we want any other implementations of OT? If so, I envisage them as being set up as their own class rather than an option in the OTDD class

Maybe, and yes, I think a separate class would be appropriate.
Overall this looks really good and is a huge improvement on the previous approach to the metrics. I've suggested some fairly minor changes.
.gitignore
```
# Slurm outputs
slurm_logs/

.DS_Store
```
With #46 merged into develop there's going to be a merge conflict on this file (and maybe some others) - worth doing a git merge and resolving it as part of this PR.

Also, agreed re. OT requiring a separate implementation to OTDD.
Looks good. The merge needs to be done and resolved, and I have some additional comments (one of which I really should have thought of last time!). Overall looks better again!
```python
# When exact calculations are used then this will be zero
# Approximate methods may be non-zero and these are checked against the known value
# for the seed
def test_cifar_otdd_same(metrics_config: dict):
```
I'm still getting a failure on both tests even with a fresh installation of everything, so we may need to settle for the 'within a range' test for the OTDD tests!
As discussed, the next update will change the tests to check for 'closeness' rather than equality.
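As an illustration, a closeness check could look something like this (a sketch only - the values and the relative tolerance are placeholders, not what the tests will actually use):

```python
import pytest


def test_otdd_closeness_sketch():
    # Placeholder values standing in for the computed OTDD and the reference
    # value stored in metrics.yaml.
    actual, expected = 98.0001, 98.0002
    assert actual == pytest.approx(expected, rel=1e-5)
```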
```python
    def __init__(self, seed: int):
        # Kernel dictionary
        self.__MMD_KERNEL_DICT = {
            "rbf": metrics.pairwise.rbf_kernel,
            "laplace": metrics.pairwise.laplacian_kernel,
        }
```
Missing `super().__init__(seed)` call.
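A minimal sketch of what the fix could look like; the `DistanceMetric` base class here is a stand-in (its real implementation lives in the PR), and only the kernel dictionary lines are taken from the hunk above:

```python
from sklearn import metrics


class DistanceMetric:
    """Stand-in for the PR's base class; only a seeded __init__ is assumed."""

    def __init__(self, seed: int):
        self.seed = seed


class MMD(DistanceMetric):
    def __init__(self, seed: int):
        super().__init__(seed)  # the missing call flagged above
        # Kernel dictionary
        self.__MMD_KERNEL_DICT = {
            "rbf": metrics.pairwise.rbf_kernel,
            "laplace": metrics.pairwise.laplacian_kernel,
        }
```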
I've added this in, though I thought the MMD calculations were deterministic.
Oh true! My bad!
Looks good, nice work. Just very minor points, and then the issue (as discussed on Slack) with checking that values are close rather than exact matches.
tests/test_data.py
```diff
@@ -367,11 +367,19 @@ def test_get_AB_data():
     train_data_b, val_data_b = dmpair.get_B_data()

     # a train
-    _compare_dataloader_to_tensor(dl=dmpair.A.train_dataloader(), data=train_data_a)
+    _compare_dataloader_to_tensor(
+        dl=dmpair.A.train_dataloader(), data=torch.tensor(train_data_a)
```
`torch.from_numpy` rather than `torch.tensor` - `from_numpy` ensures the same memory is used: https://pytorch.org/docs/stable/generated/torch.from_numpy.html
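A quick illustration of the difference (not code from this PR):

```python
import numpy as np
import torch

arr = np.zeros(3, dtype=np.float32)
shared = torch.from_numpy(arr)  # no copy: the tensor views the NumPy array's memory
copied = torch.tensor(arr)      # always copies the data into a new tensor

arr[0] = 1.0
assert shared[0] == 1.0  # the change to the array is visible through the shared tensor
assert copied[0] == 0.0  # the copy is unaffected
```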
```python
    platform.processor() == "arm",
    reason="These tests should not be run on Apple M1 devices",
)
def test_cifar_otdd_different(metrics_config: dict):
```
Similar to Phil's comment, and as discussed on Slack, I get different values on my laptop vs Azure - closeness is probably best here:

```
FAILED tests/test_otdd.py::test_cifar_otdd_different - assert 98.62165832519531 == 98.62129211425781
```
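For reference, those two values differ by roughly 4e-6 in relative terms, so a tolerance-based comparison would accept them (a quick check with illustrative tolerances):

```python
import numpy as np

# Values copied from the failing assert above; the variable names are just labels.
value_a = 98.62165832519531
value_b = 98.62129211425781
assert np.isclose(value_a, value_b, rtol=1e-5, atol=1e-8)
```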
The updates to the otdd tests are now ready to be tested. A few notes:

I wanted to write the tests so that only one assert statement would be made in each test - it would be best practice to implement them that way so that all the metric configs get tested. Currently, if the first assert fails then the others don't run, and I thought it would be useful when debugging to know whether it's only one config that has failed or all of them, for example.

I thought this should be possible by parametrising the tests (https://docs.pytest.org/en/7.1.x/example/parametrize.html). I could write the tests in pseudo-code and then had an L&D day on Friday trying to get this working in Pytest; however, it doesn't seem possible to use fixtures directly as arguments - I wanted to pass metrics_config in via parametrize, but it throws an error due to metrics_config being a fixture. I couldn't find a way around it, so I've implemented a work-around that still runs multiple checks in each test: it runs all the checks and then prints out the details of any failures. You need to run pytest in verbose mode to see those details.

I don't mind changing the code back to the multiple assert statements in one test if the preference is to keep them consistent with the mmd tests.
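For what it's worth, one pattern that sidesteps the fixtures-in-parametrize limitation is to parametrise over plain config keys and resolve the fixture at run time via pytest's built-in `request` fixture. This is only a sketch - the key names are taken from configs mentioned in this PR, but the exact keys and the test body are assumptions:

```python
import pytest


@pytest.mark.parametrize("config_key", ["otdd_exact", "otdd_naive_upperbound"])
def test_otdd_config_sketch(request, config_key):
    # Fixtures can't be passed as parametrize arguments, but they can be looked
    # up at run time through the built-in `request` fixture.
    metrics_config = request.getfixturevalue("metrics_config")
    assert config_key in metrics_config  # placeholder; a real test would compute OTDD here
```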
One minor comment, but looks good to me and happy for this to be merged in without further review. I can confirm that 23 tests pass and 2 tests are skipped for me!
```diff
@@ -151,5 +151,8 @@ attack_scripts/
 results/

+# Slurm outputs
+slurm_logs/
```
I think this line was removed in a previous PR and may have been re-added due to the merge conflict? It's no longer needed.
Tests are failing for me.

Azure:

Slight difference there in otdd_exact/same_result_only_train! Aside, is the expected val for
Looks great and passes on my Mac and Azure.
I've made a suggestion on the test structure. Happy with whatever you choose, but it should be consistent across metrics.
Nice work.
tests/test_otdd.py
"diff_result": similarity_dict, | ||
"diff_result_only_train": similarity_dict_only_train, | ||
} | ||
failures = [] |
I like the reporting that we get out of this approach. You might want to consider using something like pytest-check, which achieves a similar outcome:
```diff
-    failures = []
+    from pytest_check import check  # move this up
+    for scenario, results in test_scenarios.items():
+        for k in metrics_config:
+            expected_result = metrics_config[k]["expected_results"][scenario]
+            actual_result = results[k]
+            with check:
+                assert np.isclose(actual_result, expected_result, rtol=1e-5, atol=1e-8)
```
If you run pytest in verbose mode then you still get the detailed output of mismatched values, as you've captured here.
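(For context, pytest-check is a separate plugin, installed with `pip install pytest-check`, so it would add a small extra test dependency.)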
Happy to keep this if you'd prefer not to introduce another dependency.
Whichever approach is taken, we should do the same in the mmd tests.
The main features of the PR are:

- Created a `DistanceMetric` class
- Updated the `mmd.py` script and created MMD as a child class of `DistanceMetric`
- Added an `OTDD` class that calculates the otdd between two datasets
- Updated the `compute_similarity` method of `DMPair` so that it returns labels as well as the data; it also now returns all data as Numpy arrays (previously it could be either Numpy or a Tensor depending on whether it was only using train data or not)
- Added a `test_otdd.py` script for testing the `OTDD` class

Some thoughts on testing:

- The PyTest tests can be run as they are, but further otdd configurations could be added into the `metrics.yaml` file - currently there are only two options there and there are a lot of parameters to experiment with
- When I was testing I experienced the otdd tests failing inconsistently for the `otdd_naive_upperbound` test parameters. After setting a random seed when initialising the `OTDD` object the tests no longer failed... however, I am not 100% sure if this is a definite fix.
- The `calculate_metrics.py` script also needs testing

Some further work that needs considering after this:

- Further safeguarding to prevent metrics being calculated on Apple M1 CPUs
- The parameters to be set for `OTDD` - will the exact calculation be too computationally intensive? (raised issue Determine parameters for OTDD #48)
- Do we want any other implementations of OT? If so, I envisage them as being set up as their own class rather than an option in the `OTDD` class (raised issue Other implementations of OT #49)
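Putting the description above together, a rough sketch of how the classes might relate (the class names come from the PR; the constructor signatures, docstrings, and OTDD inheriting from DistanceMetric are assumptions):

```python
class DistanceMetric:
    """Base class for dataset distance metrics; holds shared state such as the seed."""

    def __init__(self, seed: int):
        self.seed = seed


class MMD(DistanceMetric):
    """Maximum mean discrepancy between two datasets (child of DistanceMetric)."""


class OTDD(DistanceMetric):
    """Optimal transport dataset distance between two datasets."""

    def __init__(self, seed: int):
        # Seeding at initialisation is the change discussed above for the
        # inconsistent otdd_naive_upperbound test results.
        super().__init__(seed)
```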