-
-
Notifications
You must be signed in to change notification settings - Fork 165
/
setup.py
50 lines (43 loc) · 1.34 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import platform
from setuptools import setup
from pathlib import Path
# Version shared by both distributions; bump it whenever the keras package,
# the pytorch package, or both are updated.
BASE_VERSION = '1.8.0'

# Which distribution this run builds: 'keras' (the default) or 'pytorch'.
# Read from the FRAMEWORK environment variable.
FRAMEWORK = os.getenv('FRAMEWORK', 'keras')
# A '.torch' entry in the current working directory forces the pytorch build,
# overriding whatever the environment variable said.
if Path('.torch').exists():
    FRAMEWORK = 'pytorch'

# Requirements common to both distributions.
INSTALL_REQUIRES = [
    'numpy',
    'keract',  # for the intermediate outputs
    'pandas',
    'matplotlib',
    'protobuf<=3.20.2',
]

# macOS reporting an 'arm' processor needs a different TensorFlow package name.
M1_MAC = platform.system() == 'Darwin' and platform.processor() == 'arm'
if M1_MAC:
    tensorflow = 'tensorflow-macos'
    # Workaround from https://github.com/grpc/grpc/issues/25082
    # NOTE(review): presumably these flags are read by the grpcio source
    # build; confirm against the linked issue.
    os.environ['GRPC_PYTHON_BUILD_SYSTEM_OPENSSL'] = '1'
    os.environ['GRPC_PYTHON_BUILD_SYSTEM_ZLIB'] = '1'
else:
    tensorflow = 'tensorflow'

# Select the package to ship and append its framework-specific requirements.
if FRAMEWORK == 'keras':
    LIB_PACKAGE = ['nbeats_keras']
    INSTALL_REQUIRES.extend(['keras', tensorflow])
elif FRAMEWORK == 'pytorch':
    LIB_PACKAGE = ['nbeats_pytorch']
    INSTALL_REQUIRES.extend(['torch'])
else:
    # Name the rejected value so a typo in FRAMEWORK is easy to spot.
    raise ValueError(
        f'Unknown framework: {FRAMEWORK!r}. Expected \'keras\' or \'pytorch\'.'
    )

setup(
    name=f'nbeats-{FRAMEWORK}',
    version=BASE_VERSION,
    description='N-Beats',
    author='Philippe Remy (Pytorch), Jean Sebastien Dhr (Keras)',
    license='MIT',
    long_description_content_type='text/markdown',
    # read_text() closes the file for us; the explicit encoding keeps the
    # build independent of the platform's default locale encoding.
    long_description=Path('README.md').read_text(encoding='utf-8'),
    packages=LIB_PACKAGE,
    install_requires=INSTALL_REQUIRES,
)