
7 implement ot #47

Merged 28 commits into develop from 7-implement-ot on May 9, 2023

Conversation

@joannacknight (Contributor) commented Apr 21, 2023

The main features of the PR are:

  • Introduced a new DistanceMetric class
  • Removed the mmd.py script and created MMD as a child class of DistanceMetric
  • Introduced a new OTDD class that calculates the OTDD (optimal transport dataset distance) between two datasets; a rough sketch of the intended class structure follows this list
  • Updated the compute_similarity method of DMPair so that it returns labels as well as data; it also now returns all data as NumPy arrays (previously it could return either NumPy arrays or a Tensor, depending on whether it was only using train data or not)
  • Introduced a test_otdd.py script for testing the OTDD class
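
For orientation, here is a minimal, purely illustrative sketch of the class structure described above; the method name distance and the exact signatures are assumptions for illustration, not the actual implementation:

    import numpy as np

    class DistanceMetric:
        """Base class for dataset distance metrics (illustrative sketch only)."""

        def __init__(self, seed: int):
            # store the seed so stochastic metrics can be made reproducible
            self.seed = seed

        def distance(
            self,
            data_a: np.ndarray,
            labels_a: np.ndarray,
            data_b: np.ndarray,
            labels_b: np.ndarray,
        ) -> float:
            raise NotImplementedError

    class MMD(DistanceMetric):
        """Maximum mean discrepancy, previously a standalone mmd.py script."""

    class OTDD(DistanceMetric):
        """Optimal transport dataset distance between two labelled datasets."""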

Some thoughts on testing:

  • The PyTest tests can be run as they are, but further otdd configurations could be added to the metrics.yaml file; currently there are only two options there and there are a lot of parameters to experiment with
  • As the mmd code has changed, it is necessary to verify that it still works and returns the expected results
  • While testing, I saw the otdd tests fail inconsistently for the otdd_naive_upperbound test parameters. After setting a random seed when initialising the OTDD object the tests no longer failed; however, I am not 100% sure this is a definite fix (see the seeding sketch after this list)
  • The calculate_metrics.py script also needs testing
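
For context, "setting a random seed when initialising the OTDD object" presumably means something along these lines; this helper and its name are hypothetical, shown only to illustrate the idea:

    import random

    import numpy as np
    import torch

    def seed_everything(seed: int) -> None:
        # seed all sources of randomness the OTDD computation might touch
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)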

Some further work that needs considering after this:

  • Further safeguarding to prevent metrics being calculated on Apple M1 cpu
  • The parameters to be set for OTDD - will the exact calculation be too computationally intensive?
  • Do we want any other implementations of OT? If so, I envisage them as being set up as their own class rather than an option in the OTDD class

@lannelin (Contributor) left a comment:

Nice work! I've left a few comments and am happy to discuss further.
Re: the points you made in the opening comment:

> The PyTest tests can be run as they are, but further otdd configurations could be added to the metrics.yaml file; currently there are only two options there and there are a lot of parameters to experiment with

I don't think we should worry too much about testing all param combinations; we should be able to assume that otdd itself has been tested. We just need to sanity check that we're calling it correctly, and I think we've done that.

> While testing, I saw the otdd tests fail inconsistently for the otdd_naive_upperbound test parameters. After setting a random seed when initialising the OTDD object the tests no longer failed; however, I am not 100% sure this is a definite fix.

OK, seems reasonable. I have made a specific comment on the seeding.

> The calculate_metrics.py script also needs testing

Agreed; it's up to you whether you want to create a dedicated test_calculate_metrics.py or test locally.

> Further safeguarding to prevent metrics being calculated on Apple M1 cpu

Agreed, I have made some comments.

> The parameters to be set for OTDD - will the exact calculation be too computationally intensive?

Very good point. Let's try it out on Baskerville and make some adjustments if necessary; this can be done outside of this PR.

> Do we want any other implementations of OT? If so, I envisage them as being set up as their own class rather than an option in the OTDD class

Maybe, and yes, I think a separate class would be appropriate.

@philswatton (Contributor) left a comment:

Overall this looks really good and is a huge improvement on the previous approach to the metrics. I've suggested some fairly minor changes.

.gitignore Outdated
Comment on lines 153 to 156
# Slurm outputs
slurm_logs/

.DS_Store

Contributor:

With #46 merged into develop there's going to be a merge conflict on this file (and maybe some others); it's worth doing a git merge and resolving it as part of this PR.

@philswatton (Contributor):

Also, agreed re OT requiring a separate implementation to OTDD

@joannacknight joannacknight linked an issue Apr 27, 2023 that may be closed by this pull request

@philswatton (Contributor) left a comment:

Looks good. The merge needs to be done and resolved, and I have some additional comments (one of which I really should have thought of last time!). Overall looks better again!

# When exact calculations are used then this will be zero
# Approximate methods may be non-zero and these are checked against the known value
# for the seed
def test_cifar_otdd_same(metrics_config: dict):

Contributor:

I'm still getting a failure on both tests even with a fresh installation of everything, so we may need to settle for the 'within a range' test for the OTDD tests!

Contributor (Author):

As discussed, the next update will change the tests to check for 'closeness' rather than equality
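
As an illustration of the intended change (not the final code), a closeness check could use NumPy's isclose with its default tolerances, which are the ones mentioned later in the thread:

    import numpy as np

    def assert_close(actual_result: float, expected_result: float) -> None:
        # np.isclose uses rtol=1e-05 and atol=1e-08 by default
        assert np.isclose(actual_result, expected_result)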

Comment on lines 15 to 20
def __init__(self, seed: int):
    # Kernel dictionary
    self.__MMD_KERNEL_DICT = {
        "rbf": metrics.pairwise.rbf_kernel,
        "laplace": metrics.pairwise.laplacian_kernel,
    }

Contributor:

Missing super().__init__(seed) call
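
For clarity, the fix being asked for would look roughly like this; it assumes the DistanceMetric base class introduced in this PR takes the seed in its __init__, and that metrics here is sklearn.metrics (as the kernel functions suggest):

    from sklearn import metrics  # assumption: the kernels come from sklearn.metrics.pairwise

    class DistanceMetric:  # minimal stand-in for the base class introduced in this PR
        def __init__(self, seed: int):
            self.seed = seed

    class MMD(DistanceMetric):
        def __init__(self, seed: int):
            super().__init__(seed)  # the call the review points out was missing
            # Kernel dictionary
            self.__MMD_KERNEL_DICT = {
                "rbf": metrics.pairwise.rbf_kernel,
                "laplace": metrics.pairwise.laplacian_kernel,
            }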

Contributor (Author):

I've added this in, though I thought the mmd calculations were deterministic.

Contributor:

Oh true! My bad!

@lannelin (Contributor) left a comment:

Looks good, nice work. Just very minor points, and then the issue (as discussed on Slack) with checking that values are close rather than exact matches.

@@ -367,11 +367,19 @@ def test_get_AB_data():
     train_data_b, val_data_b = dmpair.get_B_data()

     # a train
-    _compare_dataloader_to_tensor(dl=dmpair.A.train_dataloader(), data=train_data_a)
+    _compare_dataloader_to_tensor(
+        dl=dmpair.A.train_dataloader(), data=torch.tensor(train_data_a)

Contributor:

torch.from_numpy rather than torch.tensor - from_numpy ensures the tensor uses the same memory as the NumPy array: https://pytorch.org/docs/stable/generated/torch.from_numpy.html
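
A small, self-contained illustration of the difference (not taken from the PR): torch.from_numpy wraps the existing array's memory, whereas torch.tensor copies the data.

    import numpy as np
    import torch

    arr = np.zeros((2, 3), dtype=np.float32)

    shared = torch.from_numpy(arr)  # shares memory with arr, no copy
    copied = torch.tensor(arr)      # copies the data into a new tensor

    arr[0, 0] = 1.0
    assert shared[0, 0] == 1.0  # the change is visible through the shared tensor
    assert copied[0, 0] == 0.0  # the copy is unaffected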

    platform.processor() == "arm",
    reason="These tests should not be run on Apple M1 devices",
)
def test_cifar_otdd_different(metrics_config: dict):

Contributor:

Similar to Phil's comment, and as discussed on Slack, I get different values on my laptop vs Azure, so checking closeness is probably best here:
FAILED tests/test_otdd.py::test_cifar_otdd_different - assert 98.62165832519531 == 98.62129211425781

@joannacknight (Contributor, Author):

The updates to the otdd tests are now ready to be tested.

I wanted to write the tests so that each test makes only one assert statement - it would be best practice to implement them that way so that all the metric configs get tested. Currently, if the first assert fails, the remaining ones don't run, and when debugging it would be useful to know whether only one config has failed or all of them, for example.

I thought this should be possible by parametrising the tests (https://docs.pytest.org/en/7.1.x/example/parametrize.html). I could write the tests in pseudo-code, and then spent an L&D day on Friday trying to get this working in pytest; however, it doesn't seem possible to use fixtures directly as arguments. For example, I wanted to write something like the following, but it throws an error because metrics_config is a fixture:
@pytest.mark.parametrize('metric_config', [metric_config for metric_config in metrics_config])

I couldn't find a way around it, so I've implemented a work-around that still runs multiple checks in each test, but it runs all of the checks and then prints out the details of any failures (a rough sketch of this pattern is shown below). You need to run pytest in verbose mode (-v) in order to see the details of the failures if there are many. My preference is to run pytest --disable-warnings -v tests to get rid of all the warning messages too.
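
A simplified sketch of this collect-then-report pattern; the dictionary names mirror the snippets quoted later in the thread, but the function itself is illustrative rather than the exact implementation:

    import numpy as np
    import pytest

    def check_results_close(test_scenarios: dict, metrics_config: dict) -> None:
        """Compare every scenario/config pair and report all mismatches at once."""
        failures = []
        for scenario, results in test_scenarios.items():
            for k in metrics_config:
                expected_result = metrics_config[k]["expected_results"][scenario]
                actual_result = results[k]
                if not np.isclose(actual_result, expected_result):
                    failures.append(
                        f"test:{k}/{scenario}, "
                        f"expected result: {expected_result}, "
                        f"actual result: {actual_result}"
                    )
        if failures:
            # fail once with every mismatch listed, rather than stopping at the first
            pytest.fail("\n".join(failures))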

I don't mind changing the code back to the multiple assert statements in one test if the preference is to keep them consistent with the mmd tests.

A few notes:

  • I have set the expected values for the otdd tests in the metrics.yaml file to be the values that James posted in Slack
  • With the expected values set as they are:
    • Two otdd tests should be skipped for Phil, and the third (checking that an error is thrown) should pass
    • Two otdd tests should pass for James, and the third should be skipped
  • I have used the default tolerance values in the isclose calculation - do these need changing? This may depend on James' results
  • It is obviously possible to amend the expected values in the metrics.yaml file and the code in the skip statements to test the tests under different scenarios.

@philswatton (Contributor) left a comment:

One minor comment, but looks good to me and happy for this to be merged in without further review. I can confirm that 23 tests pass and 2 tests are skipped for me!

@@ -151,5 +151,8 @@ attack_scripts/
results/

# Slurm outputs
slurm_logs/

Contributor:

I think this line was removed in a previous PR and may have been re-added due to the merge conflict? It's no longer needed

@lannelin (Contributor) commented May 5, 2023:

Tests are failing for me.
MacBook:

E           Failed: test:otdd_exact/same_result, expected result: 0.0, actual result: 98.62165069580078
E           test:otdd_naive_upperbound/same_result, expected result: 256.1785888671875, actual result: 306.5455017089844
E           test:otdd_exact/same_result_only_train, expected result: 0.0, actual result: 98.76387786865234
E           test:otdd_naive_upperbound/same_result_only_train, expected result: 256.1785888671875, actual result: 306.96649169921875

Azure:

E           Failed: test:otdd_exact/same_result, expected result: 0.0, actual result: 98.62165069580078
E           test:otdd_naive_upperbound/same_result, expected result: 256.1785888671875, actual result: 306.5455017089844
E           test:otdd_exact/same_result_only_train, expected result: 0.0, actual result: 98.76387023925781
E           test:otdd_naive_upperbound/same_result_only_train, expected result: 256.1785888671875, actual result: 306.96649169921875

Slight difference there in otdd_exact/same_result_only_train!

As an aside, is the expected value for same_result_only_train worryingly high? Perhaps not.

@lannelin (Contributor) left a comment:

Looks great and passes on my Mac and on Azure.

I have made a suggestion on the test structure. Happy with whatever you choose, but it should be consistent across metrics.

Nice work.

"diff_result": similarity_dict,
"diff_result_only_train": similarity_dict_only_train,
}
failures = []

Contributor:

I like the reporting that we get out of this approach.
You might want to consider using something like pytest-check, which achieves a similar outcome.

Suggested change
-   failures = []
+   from pytest_check import check  # move this up
+   for scenario, results in test_scenarios.items():
+       for k in metrics_config:
+           expected_result = metrics_config[k]["expected_results"][scenario]
+           actual_result = results[k]
+           with check:
+               assert np.isclose(actual_result, expected_result, rtol=1e-5, atol=1e-8)

If you run pytest in verbose mode you still get the detailed output of mismatched values, as you've captured here.


Contributor:

Happy to keep this if you'd prefer not to introduce another dependency.
Whichever approach is taken, we should do the same in the mmd tests.

@joannacknight joannacknight merged commit 0d205d5 into develop May 9, 2023
@joannacknight joannacknight deleted the 7-implement-ot branch May 9, 2023 09:00
@philswatton philswatton mentioned this pull request May 9, 2023
Linked issue that may be closed by this pull request: Implement Optimal Transport Metrics