Skip to content

Commit 5bb69b0

Browse files
Rong Rongmattip
Rong Rong
andauthored
concantenate LICENSE files when building a wheel (pytorch#51634) (pytorch#51882)
Summary: Fixes pytorch#50695 I checked locally that the concatenated license file appears at `torch-<version>.dist-info/LICENSE` in the wheel. Pull Request resolved: pytorch#51634 Reviewed By: zhangguanheng66 Differential Revision: D26225550 Pulled By: walterddr fbshipit-source-id: 830c59fb7aea0eb50b99e295edddad9edab6ba3a Co-authored-by: mattip <[email protected]>
1 parent 9112f4e commit 5bb69b0

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

setup.py

+45
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,50 @@ def load(filename):
552552
with open('compile_commands.json', 'w') as f:
553553
f.write(new_contents)
554554

555+
class concat_license_files():
556+
"""Merge LICENSE and LICENSES_BUNDLED.txt as a context manager
557+
558+
LICENSE is the main PyTorch license, LICENSES_BUNDLED.txt is auto-generated
559+
from all the licenses found in ./third_party/. We concatenate them so there
560+
is a single license file in the sdist and wheels with all of the necessary
561+
licensing info.
562+
"""
563+
def __init__(self):
564+
self.f1 = 'LICENSE'
565+
self.f2 = 'third_party/LICENSES_BUNDLED.txt'
566+
567+
def __enter__(self):
568+
"""Concatenate files"""
569+
with open(self.f1, 'r') as f1:
570+
self.bsd_text = f1.read()
571+
572+
with open(self.f1, 'a') as f1:
573+
with open(self.f2, 'r') as f2:
574+
self.bundled_text = f2.read()
575+
f1.write('\n\n')
576+
f1.write(self.bundled_text)
577+
578+
def __exit__(self, exception_type, exception_value, traceback):
579+
"""Restore content of f1"""
580+
with open(self.f1, 'w') as f:
581+
f.write(self.bsd_text)
582+
583+
584+
try:
585+
from wheel.bdist_wheel import bdist_wheel
586+
except ImportError:
587+
# This is useful when wheel is not installed and bdist_wheel is not
588+
# specified on the command line. If it _is_ specified, parsing the command
589+
# line will fail before wheel_concatenate is needed
590+
wheel_concatenate = None
591+
else:
592+
# Need to create the proper LICENSE.txt for the wheel
593+
class wheel_concatenate(bdist_wheel):
594+
""" check submodules on sdist to prevent incomplete tarballs """
595+
def run(self):
596+
with concat_license_files():
597+
super().run()
598+
555599

556600
class install(setuptools.command.install.install):
557601
def run(self):
@@ -724,6 +768,7 @@ def make_relative_rpath_args(path):
724768
'build_ext': build_ext,
725769
'clean': clean,
726770
'install': install,
771+
'bdist_wheel': wheel_concatenate,
727772
}
728773

729774
entry_points = {

test/run_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
'test_fx_experimental',
109109
'test_functional_autograd_benchmark',
110110
'test_package',
111+
'test_license',
111112
'distributed/pipeline/sync/skip/test_api',
112113
'distributed/pipeline/sync/skip/test_gpipe',
113114
'distributed/pipeline/sync/skip/test_inspect_skip_layout',

test/test_license.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import glob
12
import io
3+
import os
24
import unittest
35

6+
import torch
47
from torch.testing._internal.common_utils import TestCase, run_tests
58

69

@@ -10,11 +13,14 @@
1013
create_bundled = None
1114

1215
license_file = 'third_party/LICENSES_BUNDLED.txt'
16+
starting_txt = 'The Pytorch repository and source distributions bundle'
17+
site_packages = os.path.dirname(os.path.dirname(torch.__file__))
18+
distinfo = glob.glob(os.path.join(site_packages, 'torch-*dist-info'))
1319

1420
class TestLicense(TestCase):
1521

1622
@unittest.skipIf(not create_bundled, "can only be run in a source tree")
17-
def test_license_in_wheel(self):
23+
def test_license_for_wheel(self):
1824
current = io.StringIO()
1925
create_bundled('third_party', current)
2026
with open(license_file) as fid:
@@ -25,6 +31,18 @@ def test_license_in_wheel(self):
2531
'match the current state of the third_party files. Use '
2632
'"python third_party/build_bundled.py" to regenerate it')
2733

34+
@unittest.skipIf(len(distinfo) == 0, "no installation in site-package to test")
35+
def test_distinfo_license(self):
36+
"""If run when pytorch is installed via a wheel, the license will be in
37+
site-package/torch-*dist-info/LICENSE. Make sure it contains the third
38+
party bundle of licenses"""
39+
40+
if len(distinfo) > 1:
41+
raise AssertionError('Found too many "torch-*dist-info" directories '
42+
f'in "{site_packages}, expected only one')
43+
with open(os.path.join(os.path.join(distinfo[0], 'LICENSE'))) as fid:
44+
txt = fid.read()
45+
self.assertTrue(starting_txt in txt)
2846

2947
if __name__ == '__main__':
3048
run_tests()

0 commit comments

Comments
 (0)