diff --git a/analyze-gemini.md b/analyze-gemini.md
new file mode 100644
index 00000000..778a324f
--- /dev/null
+++ b/analyze-gemini.md
@@ -0,0 +1,92 @@
+# MuseTalk 环境兼容性分析报告 (Gemini CLI) - V10 (终极方案: 原生适配 Blackwell)
+
+## 1. 基础环境现状
+- **操作系统**: Ubuntu 24.04.3 LTS (noble)
+- **Python**: 3.10 (Conda 虚拟环境: musetalk)
+- **CUDA**: 12.8 (Driver 570.211.01)
+- **GPU**: NVIDIA GeForce RTX 5070 (12GB 显存)
+- **FFmpeg**: 6.1.1 (系统已安装,状态良好)
+- **系统库**: `libsndfile1` 已预装
+
+## 2. 核心部署结论
+
+- **硬件支持**: **完全支持**。12GB 显存满足 MuseTalk 推理需求,RTX 50 系列架构性能卓越。
+- **软件策略**: **执行终极原生方案 (方案 B)**。
+ - 采用原生支持 RTX 50 系列 Blackwell 架构 (`sm_120`) 的 `PyTorch 2.7.0 (cu128)`。由于硬件架构太新,旧版 PyTorch 2.4 的 PTX 无法向前兼容。
+ - 为了适配 PyTorch 2.7.0,项目源码中的 8 处兼容性报错(`torch.load` 等)已被自动修改完毕。
+ - 采用**源码硬核编译安装 `mmcv`** 路线,突破无预编译包的限制。
+- **风险点控制**:
+ - **NumPy 必须 < 2.0.0**:NumPy 2.x 会导致 OpenCV/Librosa 崩溃。
+ - **Huggingface_hub == 0.36.2**:必须满足 transformers 4.39.2 (<1.0) 约束。
+ - **Gradio == 5.24.0**:确保 UI 稳定性并兼容旧版 Pillow。
+ - **Pillow 必须 < 10.0.0**:Pillow 10+ 移除了 `Image.ANTIALIAS`,会导致 `moviepy 1.0.3` 运行时直接崩溃。
+
+## 3. 标准部署步骤 (用户态执行)
+
+### A. 环境准备 (建议彻底重置)
+为避免之前的失败安装残留幽灵依赖,强烈建议**先删后建**:
+```bash
+conda deactivate
+conda remove -n musetalk --all -y
+conda create -n musetalk python=3.10 -y
+conda activate musetalk
+pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
+```
+
+### B. 安装原生架构 PyTorch
+```bash
+# 安装原生满血支持 Blackwell (sm_120) 的最新引擎
+pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 --force-reinstall
+```
+*(注意:MMCV 的安装已移至后续源码编译环节)*
+
+### C. 预装关键冲突依赖 (重要顺序)
+```bash
+# 1. 预装构建工具并锁死 NumPy (Python 3.12 必需)
+pip install --upgrade pip setuptools wheel "numpy<2.0.0"
+
+# 2. 解决 chumpy 编译问题 (避免 PEP 517 构建隔离报错)
+pip install chumpy --no-build-isolation
+
+# 3. 强行锁定三方黄金版本 (解决 transformers/gradio 冲突)
+pip uninstall hf-gradio -y
+pip install "huggingface_hub==0.36.2" "gradio==5.24.0" "gradio-client==1.8.0" --force-reinstall
+
+# 4. 再次检查并加固 NumPy (防止被上面的安装带跑)
+pip install "numpy<2.0.0"
+```
+
+### D. 安装 MMLab 核心组件 (使用原生 pip 源码硬核编译)
+由于我们使用了 PyTorch 2.7.0 + CUDA 12.8,官方**绝对没有**可用的预编译包,必须进行硬核的源码编译:
+```bash
+pip install -U openmim ninja Cython
+# 关键步骤:Ubuntu 24.04 默认 GCC 13 编译 CUDA 极易失败,必须使用 -allow-unsupported-compiler 放宽限制。
+# 同时必须禁用编译隔离(--no-build-isolation),否则找不到构建依赖 pkg_resources!
+NVCC_APPEND_FLAGS="-allow-unsupported-compiler" pip install "mmcv>=2.1.0" --no-build-isolation --no-cache-dir
+mim install "mmdet>=3.1.0"
+mim install "mmpose>=1.3.0"
+```
+
+### E. 安装项目剩余依赖
+```bash
+pip install -r requirements-rtx5070.txt
+```
+
+## 4. 故障排查与修复记录 (Troubleshooting)
+
+| 问题 | 原因 | 修复方式 |
+|------|------|----------|
+| `ModuleNotFoundError: No module named 'mmdet'` | `mmpose` 的核心依赖缺失 | 执行 `mim install "mmdet>=3.1.0"` |
+| `ImportError: cannot import name 'FaceAlignment'` | 全局包冲突 | `pip uninstall face_detection -y`;代码已改为相对导入 |
+| `ImportError: huggingface-hub...` | transformers 版本冲突 | 强行执行 `pip install "huggingface_hub==0.36.2" --force-reinstall` |
+| `AttributeError: module 'numpy' has no attribute...` | 使用了 NumPy 2.x | 强行降级 `pip install "numpy<2.0.0"` |
+| `chumpy` 安装失败 | 隔离构建模式问题 (setup.py import pip) | 使用 `--no-build-isolation` 且需提前装好 setuptools |
+| `AttributeError: module 'PIL.Image' has no attribute 'ANTIALIAS'` | Pillow 10+ 移除了该属性,与 moviepy 1.0.3 冲突 | 强制锁定 `pip install "pillow<10.0.0"` |
+
+## 5. 验证与运行
+1. **核心验证**: 运行 `python3 -c "import mmcv; from mmcv.ops import nms; print('验证成功: OK')"`
+2. **下载权重**: `bash download_weights.sh`
+3. **推理测试**: `bash inference.sh v1.5 normal`
+
+---
+*更新日期: 2026年5月8日 (V10 - 回到方案B,原生支持 RTX 5070 并完成源码改造)*
diff --git a/app.py b/app.py
index 448e641f..c7906385 100644
--- a/app.py
+++ b/app.py
@@ -168,6 +168,17 @@ def download_model():
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
+import torch
+
+# --- PyTorch 2.6+ Compatibility Monkey Patch ---
+_original_load = torch.load
+def _patched_load(*args, **kwargs):
+ if 'weights_only' not in kwargs:
+ kwargs['weights_only'] = False
+ return _original_load(*args, **kwargs)
+torch.load = _patched_load
+# ---------------------------------------------
+
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range
diff --git a/chumpy-0.70/Makefile b/chumpy-0.70/Makefile
new file mode 100644
index 00000000..4d717521
--- /dev/null
+++ b/chumpy-0.70/Makefile
@@ -0,0 +1,18 @@
+all:
+
+upload:
+ rm -r dist
+ python setup.py sdist
+ twine upload dist/*
+
+test:
+ # For some reason the import changes for Python 3 caused the Python 2 test
+ # loader to give up without loading any tests. So we discover them ourselves.
+ # python -m unittest
+ find chumpy -name 'test_*.py' | sed -e 's/\.py$$//' -e 's/\//./' | xargs python -m unittest
+
+coverage: clean qcov
+qcov: all
+ env LD_PRELOAD=$(PRELOADED) coverage run --source=. -m unittest discover -s .
+ coverage html
+ coverage report -m
diff --git a/chumpy-0.70/PKG-INFO b/chumpy-0.70/PKG-INFO
new file mode 100644
index 00000000..bf699007
--- /dev/null
+++ b/chumpy-0.70/PKG-INFO
@@ -0,0 +1,19 @@
+Metadata-Version: 1.1
+Name: chumpy
+Version: 0.70
+Summary: chumpy
+Home-page: https://github.com/mattloper/chumpy
+Author: Matthew Loper
+Author-email: matt.loper@gmail.com
+License: MIT
+Description: UNKNOWN
+Platform: UNKNOWN
+Classifier: Development Status :: 4 - Beta
+Classifier: Intended Audience :: Science/Research
+Classifier: Topic :: Scientific/Engineering :: Mathematics
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 2
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Operating System :: MacOS :: MacOS X
+Classifier: Operating System :: POSIX :: Linux
diff --git a/chumpy-0.70/chumpy/__init__.py b/chumpy-0.70/chumpy/__init__.py
new file mode 100644
index 00000000..21e366e0
--- /dev/null
+++ b/chumpy-0.70/chumpy/__init__.py
@@ -0,0 +1,118 @@
+from .ch import *
+from .logic import *
+
+from .optimization import minimize
+from . import extras
+from . import testing
+from .version import version as __version__
+
+from .version import version as __version__
+
+from numpy import bool, int, float, complex, object, unicode, str, nan, inf
+
+def test():
+ from os.path import split
+ import unittest
+ test_loader= unittest.TestLoader()
+ test_loader = test_loader.discover(split(__file__)[0])
+ test_runner = unittest.TextTestRunner()
+ test_runner.run( test_loader )
+
+
+demos = {}
+
+demos['scalar'] = """
+import chumpy as ch
+
+[x1, x2, x3] = ch.array(10), ch.array(20), ch.array(30)
+result = x1+x2+x3
+print result # prints [ 60.]
+print result.dr_wrt(x1) # prints 1
+"""
+
+demos['show_tree'] = """
+import chumpy as ch
+
+[x1, x2, x3] = ch.array(10), ch.array(20), ch.array(30)
+for i in range(3): x2 = x1 + x2 + x3
+
+x2.dr_wrt(x1) # pull cache
+x2.dr_wrt(x3) # pull cache
+x1.label='x1' # for clarity in show_tree()
+x2.label='x2' # for clarity in show_tree()
+x3.label='x3' # for clarity in show_tree()
+x2.show_tree(cachelim=1e-4) # in MB
+"""
+
+demos['matrix'] = """
+import chumpy as ch
+
+x1, x2, x3, x4 = ch.eye(10), ch.array(1), ch.array(5), ch.array(10)
+y = x1*(x2-x3)+x4
+print y
+print y.dr_wrt(x2)
+"""
+
+demos['linalg'] = """
+import chumpy as ch
+
+m = [ch.random.randn(100).reshape((10,10)) for i in range(3)]
+y = m[0].dot(m[1]).dot(ch.linalg.inv(m[2])) * ch.linalg.det(m[0])
+print y.shape
+print y.dr_wrt(m[0]).shape
+"""
+
+demos['inheritance'] = """
+import chumpy as ch
+import numpy as np
+
+class Sin(ch.Ch):
+
+ dterms = ('x',)
+
+ def compute_r(self):
+ return np.sin(self.x.r)
+
+ def compute_dr_wrt(self, wrt):
+ import scipy.sparse
+ if wrt is self.x:
+ result = np.cos(self.x.r)
+ return scipy.sparse.diags([result.ravel()], [0]) if len(result)>1 else np.atleast_2d(result)
+
+x1 = Ch([10,20,30])
+result = Sin(x1) # or "result = Sin(x=x1)"
+print result.r
+print result.dr_wrt(x1)
+"""
+
+demos['optimization'] = """
+import chumpy as ch
+
+x = ch.zeros(10)
+y = ch.zeros(10)
+
+# Beale's function
+e1 = 1.5 - x + x*y
+e2 = 2.25 - x + x*(y**2)
+e3 = 2.625 - x + x*(y**3)
+
+objective = {'e1': e1, 'e2': e2, 'e3': e3}
+ch.minimize(objective, x0=[x,y], method='dogleg')
+print x # should be all 3.0
+print y # should be all 0.5
+"""
+
+
+
+
+def demo(which=None):
+ if which not in demos:
+ print('Please indicate which demo you want, as follows:')
+ for key in demos:
+ print("\tdemo('%s')" % (key,))
+ return
+
+ print('- - - - - - - - - - - - - - - - - - - - - - -')
+ print(demos[which])
+ print('- - - - - - - - - - - - - - - - - - - - - - -\n')
+ exec('global np\n' + demos[which], globals(), locals())
diff --git a/chumpy-0.70/chumpy/api_compatibility.py b/chumpy-0.70/chumpy/api_compatibility.py
new file mode 100644
index 00000000..5e8ef23c
--- /dev/null
+++ b/chumpy-0.70/chumpy/api_compatibility.py
@@ -0,0 +1,534 @@
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+
+from . import ch
+import numpy as np
+from os.path import join, split
+from six import StringIO
+import numpy
+import chumpy
+from six.moves import cPickle as pickle
+
+src = ''
+num_passed = 0
+num_not_passed = 0
+which_passed = []
+
+def r(fn_name, args_req, args_opt, nplib=numpy, chlib=chumpy):
+ global num_passed, num_not_passed
+ result = [None, None]
+
+ for lib in [nplib, chlib]:
+
+ # if fn_name is 'svd' and lib is chlib:
+ # import pdb; pdb.set_trace()
+ if lib is nplib:
+ fn = getattr(lib, fn_name)
+ else:
+ try:
+ fn = getattr(lib, fn_name)
+ except AttributeError:
+ result[0] = 'missing'
+ result[1] = 'missing'
+ num_not_passed += 1
+ continue
+ try:
+ if isinstance(args_req, dict):
+ _ = fn(**args_req)
+ else:
+ _ = fn(*args_req)
+ if lib is chlib:
+ result[0] = 'passed'
+ num_passed += 1
+ global which_passed
+ which_passed.append(fn_name)
+
+ if hasattr(_, 'dterms'):
+ try:
+ _.r
+
+ try:
+ pickle.dumps(_)
+ except:
+ result[0] += ' (but unpickleable!)'
+ except:
+ import pdb; pdb.set_trace()
+ result[0] += '(but cant get result!)'
+ except Exception as e:
+ if e is TypeError:
+ import pdb; pdb.set_trace()
+ if lib is nplib:
+ import pdb; pdb.set_trace()
+ else:
+ num_not_passed += 1
+ # if fn_name == 'rot90':
+ # import pdb; pdb.set_trace()
+ result[0] = e.__class__.__name__
+
+ try:
+ if isinstance(args_req, dict):
+ fn(**dict(list(args_req.items()) + list(args_opt.items())))
+ else:
+ fn(*args_req, **args_opt)
+ if lib is chlib:
+ result[1] = 'passed'
+ except Exception as e:
+ if e is TypeError:
+ import pdb; pdb.set_trace()
+ result[1] = e.__class__.__name__
+
+ # print '%s: %s, %s' % (fn_name, result[0], result[1])
+
+ append(fn_name, result[0], result[1])
+
+def make_row(a, b, c, b_color, c_color):
+ global src
+ src += '
| %s | %s | %s |
' % (a,b_color, b,c_color, c)
+
+def append(a, b, c):
+ global src
+ b_color = 'white'
+ c_color = 'white'
+
+ b = b.replace('NotImplementedError', 'not yet implemented')
+ c = c.replace('NotImplementedError', 'not yet implemented')
+ b = b.replace('WontImplement', "won't implement")
+ c = c.replace('WontImplement', "won't implement")
+ lookup = {
+ 'passed': 'lightgreen',
+ "won't implement": 'lightgray',
+ 'untested': 'lightyellow',
+ 'not yet implemented': 'pink'
+ }
+
+ b_color = lookup[b] if b in lookup else 'white'
+ c_color = lookup[c] if c in lookup else 'white'
+
+ print('%s: %s, %s' % (a,b,c))
+ make_row(a, b, c, b_color, c_color)
+
+def m(s):
+ append(s, 'unknown', 'unknown')
+ global num_not_passed
+ num_not_passed += 1
+
+def hd3(s):
+ global src
+ src += '%s |
' % (s,)
+
+def hd2(s):
+ global src
+ src += '
'
+ src += '%s |
' % (s,)
+
+def main():
+
+ #sample_array
+
+ ###############################
+ hd2('Array Creation Routines')
+
+ hd3('Ones and zeros')
+
+ r('empty', {'shape': (2,4,2)}, {'dtype': np.uint8, 'order': 'C'})
+ r('empty_like', {'prototype': np.empty((2,4,2))}, {'dtype': np.float64, 'order': 'C'})
+ r('eye', {'N': 10}, {'M': 5, 'k': 0, 'dtype': np.float64})
+ r('identity', {'n': 10}, {'dtype': np.float64})
+ r('ones', {'shape': (2,4,2)}, {'dtype': np.uint8, 'order': 'C'})
+ r('ones_like', {'a': np.empty((2,4,2))}, {'dtype': np.float64, 'order': 'C'})
+ r('zeros', {'shape': (2,4,2)}, {'dtype': np.uint8, 'order': 'C'})
+ r('zeros_like', {'a': np.empty((2,4,2))}, {'dtype': np.float64, 'order': 'C'})
+
+ hd3('From existing data')
+ r('array', {'object': [1,2,3]}, {'dtype': np.float64, 'order': 'C', 'subok': False, 'ndmin': 2})
+ r('asarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
+ r('asanyarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
+ r('ascontiguousarray', {'a': np.array([1,2,3])}, {'dtype': np.float64})
+ r('asmatrix', {'data': np.array([1,2,3])}, {'dtype': np.float64})
+ r('copy', (np.array([1,2,3]),), {})
+ r('frombuffer', {'buffer': np.array([1,2,3])}, {})
+ m('fromfile')
+ r('fromfunction', {'function': lambda i, j: i + j, 'shape': (3, 3)}, {'dtype': np.float64})
+ # function, shape, **kwargs
+ # lambda i, j: i + j, (3, 3), dtype=int
+ r('fromiter', {'iter': [1,2,3,4], 'dtype': np.float64}, {'count': 2})
+ r('fromstring', {'string': '\x01\x02', 'dtype': np.uint8}, {})
+ r('loadtxt', {'fname': StringIO("0 1\n2 3")}, {})
+
+ hd3('Creating record arrays (wont be implemented)')
+ hd3('Creating character arrays (wont be implemented)')
+
+ hd3('Numerical ranges')
+ r('arange', {'start': 0, 'stop': 10}, {'step': 2, 'dtype': np.float64})
+ r('linspace', {'start': 0, 'stop': 10}, {'num': 2, 'endpoint': 10, 'retstep': 1})
+ r('logspace', {'start': 0, 'stop': 10}, {'num': 2, 'endpoint': 10, 'base': 1})
+ r('meshgrid', ([1,2,3], [4,5,6]), {})
+ m('mgrid')
+ m('ogrid')
+
+ hd3('Building matrices')
+ r('diag', {'v': np.arange(9).reshape((3,3))}, {'k': 0})
+ r('diagflat', {'v': [[1,2], [3,4]]}, {})
+ r('tri', {'N': 3}, {'M': 5, 'k': 2, 'dtype': np.float64})
+ r('tril', {'m': [[1,2,3],[4,5,6],[7,8,9],[10,11,12]]}, {'k': -1})
+ r('triu', {'m': [[1,2,3],[4,5,6],[7,8,9],[10,11,12]]}, {'k': -1})
+ r('vander', {'x': np.array([1, 2, 3, 5])}, {'N': 3})
+
+ ###############################
+ hd2('Array manipulation routines')
+
+ hd3('Basic operations')
+ r('copyto', {'dst': np.eye(3), 'src': np.eye(3)}, {})
+
+ hd3('Changing array shape')
+ r('reshape', {'a': np.eye(3), 'newshape': (9,)}, {'order' : 'C'})
+ r('ravel', {'a': np.eye(3)}, {'order' : 'C'})
+ m('flat')
+ m('flatten')
+
+ hd3('Transpose-like operations')
+ r('rollaxis', {'a': np.ones((3,4,5,6)), 'axis': 3}, {'start': 0})
+ r('swapaxes', {'a': np.array([[1,2,3]]), 'axis1': 0, 'axis2': 1}, {})
+ r('transpose', {'a': np.arange(4).reshape((2,2))}, {'axes': (1,0)})
+
+ hd3('Changing number of dimensions')
+ r('atleast_1d', (np.eye(3),), {})
+ r('atleast_2d', (np.eye(3),), {})
+ r('atleast_3d', (np.eye(3),), {})
+ m('broadcast')
+ m('broadcast_arrays')
+ r('expand_dims', (np.array([1,2]),2), {})
+ r('squeeze', {'a': (np.array([[[1,2,3]]]))}, {})
+
+ hd3('Changing kind of array')
+ r('asarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
+ r('asanyarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
+ r('asmatrix', {'data': np.array([1,2,3])}, {})
+ r('asfarray', {'a': np.array([1,2,3])}, {})
+ r('asfortranarray', {'a': np.array([1,2,3])}, {})
+ r('asscalar', {'a': np.array([24])}, {})
+ r('require', {'a': np.array([24])}, {})
+
+ hd3('Joining arrays')
+ m('column_stack')
+ r('concatenate', ((np.eye(3), np.eye(3)),1), {})
+ r('dstack', ((np.eye(3), np.eye(3)),), {})
+ r('hstack', ((np.eye(3), np.eye(3)),), {})
+ r('vstack', ((np.eye(3), np.eye(3)),), {})
+
+ hd3('Splitting arrays')
+ m('array_split')
+ m('dsplit')
+ m('hsplit')
+ m('split')
+ m('vsplit')
+
+ hd3('Tiling arrays')
+ r('tile', (np.array([0, 1, 2]),2), {})
+ r('repeat', (np.array([[1,2],[3,4]]), 3), {'axis': 1})
+
+ hd3('Adding and removing elements')
+ m('delete')
+ m('insert')
+ m('append')
+ m('resize')
+ m('trim_zeros')
+ m('unique')
+
+ hd3('Rearranging elements')
+ r('fliplr', (np.eye(3),), {})
+ r('flipud', (np.eye(3),), {})
+ r('reshape', {'a': np.eye(3), 'newshape': (9,)}, {'order' : 'C'})
+ r('roll', (np.arange(10), 2), {})
+ r('rot90', (np.arange(4).reshape((2,2)),), {})
+
+ ###############################
+ hd2('Linear algebra (numpy.linalg)')
+
+ extra_args = {'nplib': numpy.linalg, 'chlib': ch.linalg}
+
+ hd3('Matrix and dot products')
+ r('dot', {'a': np.eye(3), 'b': np.eye(3)}, {})
+ r('dot', {'a': np.eye(3).ravel(), 'b': np.eye(3).ravel()}, {})
+ r('vdot', (np.eye(3).ravel(), np.eye(3).ravel()), {})
+ r('inner', (np.eye(3).ravel(), np.eye(3).ravel()), {})
+ r('outer', (np.eye(3).ravel(), np.eye(3).ravel()), {})
+ r('tensordot', {'a': np.eye(3), 'b': np.eye(3)}, {})
+ m('einsum')
+ r('matrix_power', {'M': np.eye(3), 'n': 2}, {}, **extra_args)
+ r('kron', {'a': np.eye(3), 'b': np.eye(3)}, {})
+
+ hd3('Decompositions')
+ r('cholesky', {'a': np.eye(3)}, {}, **extra_args)
+ r('qr', {'a': np.eye(3)}, {}, **extra_args)
+ r('svd', (np.eye(3),), {}, **extra_args)
+
+ hd3('Matrix eigenvalues')
+ r('eig', (np.eye(3),), {}, **extra_args)
+ r('eigh', (np.eye(3),), {}, **extra_args)
+ r('eigvals', (np.eye(3),), {}, **extra_args)
+ r('eigvalsh', (np.eye(3),), {}, **extra_args)
+
+ hd3('Norms and other numbers')
+ r('norm', (np.eye(3),), {}, **extra_args)
+ r('cond', (np.eye(3),), {}, **extra_args)
+ r('det', (np.eye(3),), {}, **extra_args)
+ r('slogdet', (np.eye(3),), {}, **extra_args)
+ r('trace', (np.eye(3),), {})
+
+ hd3('Solving equations and inverting matrices')
+ r('solve', (np.eye(3),np.ones(3)), {}, **extra_args)
+ r('tensorsolve', (np.eye(3),np.ones(3)), {}, **extra_args)
+ r('lstsq', (np.eye(3),np.ones(3)), {}, **extra_args)
+ r('inv', (np.eye(3),), {}, **extra_args)
+ r('pinv', (np.eye(3),), {}, **extra_args)
+ r('tensorinv', (np.eye(4*6).reshape((4,6,8,3)),), {'ind': 2}, **extra_args)
+
+ hd2('Mathematical functions')
+
+ hd3('Trigonometric functions')
+ r('sin', (np.arange(3),), {})
+ r('cos', (np.arange(3),), {})
+ r('tan', (np.arange(3),), {})
+ r('arcsin', (np.arange(3)/3.,), {})
+ r('arccos', (np.arange(3)/3.,), {})
+ r('arctan', (np.arange(3)/3.,), {})
+ r('hypot', (np.arange(3),np.arange(3)), {})
+ r('arctan2', (np.arange(3),np.arange(3)), {})
+ r('degrees', (np.arange(3),), {})
+ r('radians', (np.arange(3),), {})
+ r('unwrap', (np.arange(3),), {})
+ r('unwrap', (np.arange(3),), {})
+ r('deg2rad', (np.arange(3),), {})
+ r('rad2deg', (np.arange(3),), {})
+
+ hd3('Hyperbolic functions')
+ r('sinh', (np.arange(3),), {})
+ r('cosh', (np.arange(3),), {})
+ r('tanh', (np.arange(3),), {})
+ r('arcsinh', (np.arange(3)/9.,), {})
+ r('arccosh', (-np.arange(3)/9.,), {})
+ r('arctanh', (np.arange(3)/9.,), {})
+
+ hd3('Rounding')
+ r('around', (np.arange(3),), {})
+ r('round_', (np.arange(3),), {})
+ r('rint', (np.arange(3),), {})
+ r('fix', (np.arange(3),), {})
+ r('floor', (np.arange(3),), {})
+ r('ceil', (np.arange(3),), {})
+ r('trunc', (np.arange(3),), {})
+
+ hd3('Sums, products, differences')
+ r('prod', (np.arange(3),), {})
+ r('sum', (np.arange(3),), {})
+ r('nansum', (np.arange(3),), {})
+ r('cumprod', (np.arange(3),), {})
+ r('cumsum', (np.arange(3),), {})
+ r('diff', (np.arange(3),), {})
+ r('ediff1d', (np.arange(3),), {})
+ r('gradient', (np.arange(3),), {})
+ r('cross', (np.arange(3), np.arange(3)), {})
+ r('trapz', (np.arange(3),), {})
+
+ hd3('Exponents and logarithms')
+ r('exp', (np.arange(3),), {})
+ r('expm1', (np.arange(3),), {})
+ r('exp2', (np.arange(3),), {})
+ r('log', (np.arange(3),), {})
+ r('log10', (np.arange(3),), {})
+ r('log2', (np.arange(3),), {})
+ r('log1p', (np.arange(3),), {})
+ r('logaddexp', (np.arange(3), np.arange(3)), {})
+ r('logaddexp2', (np.arange(3), np.arange(3)), {})
+
+ hd3('Other special functions')
+ r('i0', (np.arange(3),), {})
+ r('sinc', (np.arange(3),), {})
+
+ hd3('Floating point routines')
+ r('signbit', (np.arange(3),), {})
+ r('copysign', (np.arange(3), np.arange(3)), {})
+ r('frexp', (np.arange(3),), {})
+ r('ldexp', (np.arange(3), np.arange(3)), {})
+
+ hd3('Arithmetic operations')
+ r('add', (np.arange(3), np.arange(3)), {})
+ r('reciprocal', (np.arange(3),), {})
+ r('negative', (np.arange(3),), {})
+ r('multiply', (np.arange(3), np.arange(3)), {})
+ r('divide', (np.arange(3), np.arange(3)), {})
+ r('power', (np.arange(3), np.arange(3)), {})
+ r('subtract', (np.arange(3), np.arange(3)), {})
+ r('true_divide', (np.arange(3), np.arange(3)), {})
+ r('floor_divide', (np.arange(3), np.arange(3)), {})
+ r('fmod', (np.arange(3), np.arange(3)), {})
+ r('mod', (np.arange(3), np.arange(3)), {})
+ r('modf', (np.arange(3),), {})
+ r('remainder', (np.arange(3), np.arange(3)), {})
+
+ hd3('Handling complex numbers')
+ m('angle')
+ m('real')
+ m('imag')
+ m('conj')
+
+ hd3('Miscellaneous')
+ r('convolve', (np.arange(3), np.arange(3)), {})
+ r('clip', (np.arange(3), 0, 2), {})
+ r('sqrt', (np.arange(3),), {})
+ r('square', (np.arange(3),), {})
+ r('absolute', (np.arange(3),), {})
+ r('fabs', (np.arange(3),), {})
+ r('sign', (np.arange(3),), {})
+ r('maximum', (np.arange(3), np.arange(3)), {})
+ r('minimum', (np.arange(3), np.arange(3)), {})
+ r('fmax', (np.arange(3), np.arange(3)), {})
+ r('fmin', (np.arange(3), np.arange(3)), {})
+ r('nan_to_num', (np.arange(3),), {})
+ r('real_if_close', (np.arange(3),), {})
+ r('interp', (2.5, [1,2,3], [3,2,0]), {})
+
+ extra_args = {'nplib': numpy.random, 'chlib': ch.random}
+
+ hd2('Random sampling (numpy.random)')
+ hd3('Simple random data')
+ r('rand', (3,), {}, **extra_args)
+ r('randn', (3,), {}, **extra_args)
+ r('randint', (3,), {}, **extra_args)
+ r('random_integers', (3,), {}, **extra_args)
+ r('random_sample', (3,), {}, **extra_args)
+ r('random', (3,), {}, **extra_args)
+ r('ranf', (3,), {}, **extra_args)
+ r('sample', (3,), {}, **extra_args)
+ r('choice', (np.ones(3),), {}, **extra_args)
+ r('bytes', (3,), {}, **extra_args)
+
+ hd3('Permutations')
+ r('shuffle', (np.ones(3),), {}, **extra_args)
+ r('permutation', (3,), {}, **extra_args)
+
+ hd3('Distributions (these all pass)')
+ r('beta', (.5, .5), {}, **extra_args)
+ r('binomial', (.5, .5), {}, **extra_args)
+ r('chisquare', (.5,), {}, **extra_args)
+ r('dirichlet', ((10, 5, 3), 20,), {}, **extra_args)
+ r('exponential', [], {}, **extra_args)
+ r('f', [1,48,1000], {}, **extra_args)
+ r('gamma', [.5], {}, **extra_args)
+ make_row('...AND 28 OTHERS...', 'passed', 'passed', 'lightgreen', 'lightgreen')
+
+
+ hd3('Random generator')
+ r('seed', [], {}, **extra_args)
+ r('get_state', [], {}, **extra_args)
+ r('set_state', [np.random.get_state()], {}, **extra_args)
+
+ ####################################
+ hd2('Statistics')
+ hd3('Order statistics')
+ r('amin', (np.eye(3),),{})
+ r('amax', (np.eye(3),),{})
+ r('nanmin', (np.eye(3),),{})
+ r('nanmax', (np.eye(3),),{})
+ r('ptp', (np.eye(3),),{})
+ r('percentile', (np.eye(3),50),{})
+
+ hd3('Averages and variance')
+ r('median', (np.eye(3),),{})
+ r('average', (np.eye(3),),{})
+ r('mean', (np.eye(3),),{})
+ r('std', (np.eye(3),),{})
+ r('var', (np.eye(3),),{})
+ r('nanmean', (np.eye(3),),{})
+ r('nanstd', (np.eye(3),),{})
+ r('nanvar', (np.eye(3),),{})
+
+
+ hd3('Correlating')
+ r('corrcoef', (np.eye(3),),{})
+ r('correlate', ([1, 2, 3], [0, 1, 0.5]),{})
+ r('cov', (np.eye(3),),{})
+
+ hd3('Histograms')
+ r('histogram', (np.eye(3),),{})
+ r('histogram2d', (np.eye(3).ravel(),np.eye(3).ravel()),{})
+ r('histogramdd', (np.eye(3).ravel(),),{})
+ r('bincount', (np.asarray(np.eye(3).ravel(), np.uint32),),{})
+ r('digitize', (np.array([0.2, 6.4, 3.0, 1.6]), np.array([0.0, 1.0, 2.5, 4.0, 10.0])),{})
+
+ ####################################
+ hd2('Sorting, searching, and counting')
+
+ hd3('Sorting')
+ r('sort', (np.array([1,3,1,2.]),), {})
+ m('lexsort')
+ m('argsort')
+ m('msort')
+ m('sort_complex')
+ m('partition')
+ m('argpartition')
+
+# sort(a[, axis, kind, order]) Return a sorted copy of an array.
+# lexsort(keys[, axis]) Perform an indirect sort using a sequence of keys.
+# argsort(a[, axis, kind, order]) Returns the indices that would sort an array.
+# ndarray.sort([axis, kind, order]) Sort an array, in-place.
+# msort(a) Return a copy of an array sorted along the first axis.
+# sort_complex(a) Sort a complex array using the real part first, then the imaginary part.
+# partition(a, kth[, axis, kind, order]) Return a partitioned copy of an array.
+# argpartition(a, kth[, axis, kind, order]) Perform an indirect partition along the given axis using the algorithm specified by the kind keyword.
+
+ a5 = np.arange(5)
+
+ hd3('Searching')
+ r('argmax', (a5,), {})
+ r('nanargmax', (a5,), {})
+ r('argmin', (a5,), {})
+ r('nanargmin', (a5,), {})
+ r('argwhere', (a5,), {})
+ r('nonzero', (a5,), {})
+ r('flatnonzero', (a5,), {})
+ r('where', (a5>1,), {})
+ r('searchsorted', (a5,a5), {})
+ r('extract', (lambda x : x > 1, a5), {})
+
+# argmax(a[, axis]) Indices of the maximum values along an axis.
+# nanargmax(a[, axis]) Return the indices of the maximum values in the specified axis ignoring
+# argmin(a[, axis]) Return the indices of the minimum values along an axis.
+# nanargmin(a[, axis]) Return the indices of the minimum values in the specified axis ignoring
+# argwhere(a) Find the indices of array elements that are non-zero, grouped by element.
+# nonzero(a) Return the indices of the elements that are non-zero.
+# flatnonzero(a) Return indices that are non-zero in the flattened version of a.
+# where(condition, [x, y]) Return elements, either from x or y, depending on condition.
+# searchsorted(a, v[, side, sorter]) Find indices where elements should be inserted to maintain order.
+# extract(condition, arr) Return the elements of an array that satisfy some condition.
+
+ hd3('Counting')
+ r('count_nonzero', (a5,), {})
+ #count_nonzero(a) Counts the number of non-zero values in the array a.
+
+
+
+# histogram(a[, bins, range, normed, weights, ...]) Compute the histogram of a set of data.
+# histogram2d(x, y[, bins, range, normed, weights]) Compute the bi-dimensional histogram of two data samples.
+# histogramdd(sample[, bins, range, normed, ...]) Compute the multidimensional histogram of some data.
+# bincount(x[, weights, minlength]) Count number of occurrences of each value in array of non-negative ints.
+# digitize(x, bins[, right]) Return the indices of the bins to which each value in input array belongs.
+
+
+ global src
+ src = ''
+ open(join(split(__file__)[0], 'api_compatibility.html'), 'w').write(src)
+
+ print('passed %d, not passed %d' % (num_passed, num_not_passed))
+
+
+
+if __name__ == '__main__':
+ global which_passed
+ main()
+ print(' '.join(which_passed))
diff --git a/chumpy-0.70/chumpy/ch.py b/chumpy-0.70/chumpy/ch.py
new file mode 100644
index 00000000..e903df98
--- /dev/null
+++ b/chumpy-0.70/chumpy/ch.py
@@ -0,0 +1,1366 @@
+#!/usr/bin/env python
+# encoding: utf-8
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+
+__all__ = ['Ch', 'depends_on', 'MatVecMult', 'ChHandle', 'ChLambda']
+
+import os, sys, time
+import inspect
+import scipy.sparse as sp
+import numpy as np
+import numbers
+import weakref
+import copy as external_copy
+from functools import wraps
+from scipy.sparse.linalg.interface import LinearOperator
+from .utils import row, col, timer, convert_inputs_to_sparse_if_necessary
+import collections
+from copy import deepcopy
+from functools import reduce
+
+
+
+# Turn this on if you want the profiler injected
+DEBUG = False
+# Turn this on to make optimizations very chatty for debugging
+VERBOSE = False
+def pif(msg):
+ # print-if-verbose.
+ if DEBUG or VERBOSE:
+ sys.stdout.write(msg + '\n')
+
+_props_for_dict = weakref.WeakKeyDictionary()
+def _props_for(cls):
+ if cls not in _props_for_dict:
+ _props_for_dict[cls] = set([p[0] for p in inspect.getmembers(cls, lambda x : isinstance(x, property))])
+ return _props_for_dict[cls]
+
+_dep_props_for_dict = weakref.WeakKeyDictionary()
+def _dep_props_for(cls):
+ if cls not in _dep_props_for_dict:
+ _dep_props_for_dict[cls] = [p for p in inspect.getmembers(cls, lambda x : isinstance(x, property)) if hasattr(p[1].fget, 'deps')]
+ return _dep_props_for_dict[cls]
+
+
+_kw_conflict_dict = weakref.WeakKeyDictionary()
+def _check_kw_conflict(cls):
+ if cls not in _kw_conflict_dict:
+ _kw_conflict_dict[cls] = Ch._reserved_kw.intersection(set(cls.terms).union(set(cls.dterms)))
+ if _kw_conflict_dict[cls]:
+ raise Exception("In class %s, don't use reserved keywords in terms/dterms: %s" % (str(cls), str(kw_conflict),))
+
+
+class Term(object):
+ creation_counter = 0
+ def __init__(self, default=None, desc=None, dr=True):
+ self.default = default
+ self.desc = desc
+ self.dr = dr
+
+ # Add a creation_counter, a la Django models, so we can preserve the order in which parameters are defined in the job.
+ # http://stackoverflow.com/a/3288801/893113
+ self.creation_counter = Term.creation_counter
+ Term.creation_counter += 1
+
+
+class Ch(object):
+ terms = []
+ dterms = ['x']
+ __array_priority__ = 2.0
+ _cached_parms = {}
+ _setup_terms = {}
+ _default_kwargs = {'make_dense' : False, 'make_sparse' : False}
+ _status = "undefined"
+
+ called_dr_wrt = False
+ profiler = None
+
+ ########################################################
+ # Construction
+
+ def __new__(cls, *args, **kwargs):
+
+ if len(args) > 0 and type(args[0]) == type(lambda : 0):
+ cls = ChLambda
+
+ # Create empty instance
+ result = super(Ch, cls).__new__(cls)
+
+ cls.setup_terms()
+
+ object.__setattr__(result, '_dirty_vars', set())
+ object.__setattr__(result, '_itr', None)
+ object.__setattr__(result, '_parents', weakref.WeakKeyDictionary())
+ object.__setattr__(result, '_cache', {'r': None, 'drs': weakref.WeakKeyDictionary()})
+
+ if DEBUG:
+ object.__setattr__(result, '_cache_info', {})
+ object.__setattr__(result, '_status', 'new')
+
+ for name, default_val in list(cls._default_kwargs.items()):
+ object.__setattr__(result, '_%s' % name, kwargs.get(name, default_val))
+ if name in kwargs:
+ del kwargs[name]
+
+ # Set up storage that allows @depends_on to work
+ #props = [p for p in inspect.getmembers(cls, lambda x : isinstance(x, property)) if hasattr(p[1].fget, 'deps')]
+ props = _dep_props_for(cls)
+ cpd = {}
+ for p in props:
+ func_name = p[0] #id(p[1].fget)
+ deps = p[1].fget.deps
+ cpd[func_name] = {'deps': deps, 'value': None, 'out_of_date': True}
+
+ object.__setattr__(result, '_depends_on_deps', cpd)
+
+ if cls != Ch:
+ for idx, a in enumerate(args):
+ kwargs[cls.term_order[idx]] = a
+ elif len(args)>0:
+ kwargs['x'] = np.asarray(args[0], np.float64)
+
+ defs = {p.name : deepcopy(p.default) for p in cls.parm_declarations() if p.default is not None}
+ defs.update(kwargs)
+ result.set(**defs)
+
+ return result
+
+ @classmethod
+ def parm_declarations(cls):
+ if cls.__name__ not in cls._cached_parms:
+ parameter_declarations = collections.OrderedDict()
+ parameters = inspect.getmembers(cls, lambda x: isinstance(x, Term))
+ for name, decl in sorted(parameters, key=lambda x: x[1].creation_counter):
+ decl.name = name
+ parameter_declarations[name] = decl
+ cls._cached_parms[cls.__name__] = parameter_declarations
+ return cls._cached_parms[cls.__name__]
+
+ @classmethod
+ def setup_terms(cls):
+ if id(cls) in cls._setup_terms: return
+
+ if cls == Ch:
+ return
+
+ parm_declarations = cls.parm_declarations()
+
+ if cls.dterms is Ch.dterms:
+ cls.dterms = []
+ elif isinstance(cls.dterms, str):
+ cls.dterms = (cls.dterms,)
+ if cls.terms is Ch.terms:
+ cls.terms = []
+ elif isinstance(cls.terms, str):
+ cls.terms = (cls.terms,)
+
+ # Must be either new or old style
+ len_oldstyle_parms = len(cls.dterms)+len(cls.terms)
+ if len(parm_declarations) > 0:
+ assert(len_oldstyle_parms==0)
+ cls.term_order = [t.name for t in parm_declarations]
+ cls.dterms = [t.name for t in parm_declarations if t.dr]
+ cls.terms = [t.name for t in parm_declarations if not t.dr]
+ else:
+ if not hasattr(cls, 'term_order'):
+ cls.term_order = list(cls.terms) + list(cls.dterms)
+
+ _check_kw_conflict(cls)
+ cls._setup_terms[id(cls)] = True
+
+
+ ########################################################
+ # Identifiers
+
+ @property
+ def short_name(self):
+ return self.label if hasattr(self, 'label') else self.__class__.__name__
+
+ @property
+ def sid(self):
+ """Semantic id."""
+ pnames = list(self.terms)+list(self.dterms)
+ pnames.sort()
+ return (self.__class__, tuple([(k, id(self.__dict__[k])) for k in pnames if k in self.__dict__]))
+
+
+ def reshape(self, *args):
+ return reshape(a=self, newshape=args if len(args)>1 else args[0])
+
+ def ravel(self):
+ return reshape(a=self, newshape=(-1))
+
+ def __hash__(self):
+ return id(self)
+
+ @property
+ def ndim(self):
+ return self.r.ndim
+
+ @property
+ def flat(self):
+ return self.r.flat
+
+ @property
+ def dtype(self):
+ return self.r.dtype
+
+ @property
+ def itemsize(self):
+ return self.r.itemsize
+
+
+ ########################################################
+ # Redundancy removal
+
+ def remove_redundancy(self, cache=None, iterate=True):
+
+ if cache == None:
+ cache = {}
+ _ = self.r # may result in the creation of extra dterms that we can cull
+
+ replacement_occurred = False
+ for propname in list(self.dterms):
+ prop = self.__dict__[propname]
+
+ if not hasattr(prop, 'dterms'):
+ continue
+ sid = prop.sid
+ if sid not in cache:
+ cache[sid] = prop
+ elif self.__dict__[propname] is not cache[sid]:
+ self.__dict__[propname] = cache[sid]
+ replacement_occurred = True
+ if prop.remove_redundancy(cache, iterate=False):
+ replacement_occurred = True
+
+ if not replacement_occurred:
+ return False
+ else:
+ if iterate:
+ self.remove_redundancy(cache, iterate=True)
+ return False
+ else:
+ return True
+
+
+
+ def print_labeled_residuals(self, print_newline=True, num_decimals=2, where_to_print=None):
+
+ if where_to_print is None:
+ where_to_print = sys.stderr
+ if hasattr(self, 'label'):
+ where_to_print.write(('%s: %.' + str(num_decimals) + 'e | ') % (self.label, np.sum(self.r**2)))
+ for dterm in self.dterms:
+ dt = getattr(self, dterm)
+ if hasattr(dt, 'dterms'):
+ dt.print_labeled_residuals(print_newline=False, where_to_print=where_to_print)
+ if print_newline:
+ where_to_print.write(('%.' + str(num_decimals) + 'e\n') % (np.sum(self.r**2),))
+
+
+
+ ########################################################
+ # Default methods, for when Ch is not subclassed
+
+ def compute_r(self):
+ """Default method for objects that just contain a number or ndarray"""
+ return self.x
+
+ def compute_dr_wrt(self,wrt):
+ """Default method for objects that just contain a number or ndarray"""
+ if wrt is self: # special base case
+ return sp.eye(self.x.size, self.x.size)
+ #return np.array([[1]])
+ return None
+
+
+ def _compute_dr_wrt_sliced(self, wrt):
+ self._call_on_changed()
+
+ # if wrt is self:
+ # return np.array([[1]])
+
+ result = self.compute_dr_wrt(wrt)
+ if result is not None:
+ return result
+
+ # What allows slicing.
+ if True:
+ inner = wrt
+ while issubclass(inner.__class__, Permute):
+ inner = inner.a
+ if inner is self:
+ return None
+ result = self.compute_dr_wrt(inner)
+
+ if result is not None:
+ break
+
+ if result is None:
+ return None
+
+ wrt._call_on_changed()
+
+ jac = wrt.compute_dr_wrt(inner).T
+
+ return self._superdot(result, jac)
+
+
+ @property
+ def shape(self):
+ return self.r.shape
+
+ @property
+ def size(self):
+ #return self.r.size
+ return np.prod(self.shape) # may be cheaper since it doesn't always mean grabbing "r"
+
+ def __len__(self):
+ return len(self.r)
+
+ def minimize(self, *args, **kwargs):
+ from . import optimization
+ return optimization.minimize(self, *args, **kwargs)
+
+ def __array__(self, *args):
+ return self.r
+
+ ########################################################
+ # State management
+
+ def add_dterm(self, dterm_name, dterm):
+ self.dterms = list(set(list(self.dterms) + [dterm_name]))
+ setattr(self, dterm_name, dterm)
+
+ def copy(self):
+ return copy(self)
+
+ def __getstate__(self):
+ # Have to get rid of WeakKeyDictionaries for serialization
+ result = external_copy.copy(self.__dict__)
+ del result['_parents']
+ del result['_cache']
+ return result
+
+ def __setstate__(self, d):
+ # Restore unpickleable WeakKeyDictionaries
+ d['_parents'] = weakref.WeakKeyDictionary()
+ d['_cache'] = {'r': None, 'drs': weakref.WeakKeyDictionary()}
+ object.__setattr__(self, '__dict__', d)
+
+ # This restores our unpickleable "_parents" attribute
+ for k in set(self.dterms).intersection(set(self.__dict__.keys())):
+ setattr(self, k, self.__dict__[k])
+
+ def __setattr__(self, name, value, itr=None):
+ #print 'SETTING %s' % (name,)
+
+ # Faster path for basic Ch objects. Not necessary for functionality,
+ # but improves performance by a small amount.
+ if type(self) == Ch:
+ if name == 'x':
+ self._dirty_vars.add(name)
+ self.clear_cache(itr)
+ #else:
+ # import warnings
+ # warnings.warn('Trying to set attribute %s on a basic Ch object? Might be a mistake.' % (name,))
+
+ object.__setattr__(self, name, value)
+ return
+
+ name_in_dterms = name in self.dterms
+ name_in_terms = name in self.terms
+ name_in_props = name in _props_for(self.__class__)# [p[0] for p in inspect.getmembers(self.__class__, lambda x : isinstance(x, property))]
+
+ if name_in_dterms and not name_in_props and type(self) != Ch:
+ if not hasattr(value, 'dterms'):
+ value = Ch(value)
+
+ # Make ourselves not the parent of the old value
+ if hasattr(self, name):
+ term = getattr(self, name)
+ if self in term._parents:
+ term._parents[self]['varnames'].remove(name)
+ if len(term._parents[self]['varnames']) == 0:
+ del term._parents[self]
+
+ # Make ourselves parents of the new value
+ if self not in value._parents:
+ value._parents[self] = {'varnames': set([name])}
+ else:
+ value._parents[self]['varnames'].add(name)
+
+ if name_in_dterms or name_in_terms:
+ self._dirty_vars.add(name)
+ self._invalidate_cacheprop_names([name])
+
+ # If one of our terms has changed, it has the capacity to have
+ # changed our result and all our derivatives wrt everything
+ self.clear_cache(itr)
+
+ object.__setattr__(self, name, value)
+
+ def _invalidate_cacheprop_names(self, names):
+ nameset = set(names)
+ for func_name, v in list(self._depends_on_deps.items()):
+ if len(nameset.intersection(v['deps'])) > 0:
+ v['out_of_date'] = True
+
+
+ def clear_cache(self, itr=None):
+ todo = [self]
+ done = set([])
+ nodes_visited = 0
+ while len(todo) > 0:
+ nodes_visited += 1
+ next = todo.pop()
+ if itr is not None and itr==next._itr:
+ continue
+ if id(next) not in done:
+ next._cache['r'] = None
+ next._cache['drs'].clear()
+ next._itr = itr
+
+ for parent, parent_dict in list(next._parents.items()):
+ object.__setattr__(parent, '_dirty_vars', parent._dirty_vars.union(parent_dict['varnames']))
+ parent._invalidate_cacheprop_names(parent_dict['varnames'])
+ todo.append(parent)
+ done.add(id(next))
+ return nodes_visited
+
+
+ def clear_cache_wrt(self, wrt, itr=None):
+ if wrt in self._cache['drs']:
+ self._cache['drs'][wrt] = None
+
+ if hasattr(self, 'dr_cached') and wrt in self.dr_cached:
+ self.dr_cached[wrt] = None
+
+ if itr is None or itr != self._itr:
+ for parent, parent_dict in list(self._parents.items()):
+ if wrt in parent._cache['drs'] or (hasattr(parent, 'dr_cached') and wrt in parent.dr_cached):
+ parent.clear_cache_wrt(wrt=wrt, itr=itr)
+ object.__setattr__(parent, '_dirty_vars', parent._dirty_vars.union(parent_dict['varnames']))
+ parent._invalidate_cacheprop_names(parent_dict['varnames'])
+
+ object.__setattr__(self, '_itr', itr)
+
+ def replace(self, old, new):
+ if (hasattr(old, 'dterms') != hasattr(new, 'dterms')):
+ raise Exception('Either "old" and "new" must both be "Ch", or they must both be neither.')
+
+ for term_name in [t for t in list(self.dterms)+list(self.terms) if hasattr(self, t)]:
+ term = getattr(self, term_name)
+ if term is old:
+ setattr(self, term_name, new)
+ elif hasattr(term, 'dterms'):
+ term.replace(old, new)
+ return new
+
+
+ def set(self, **kwargs):
+ # Some dterms may be aliases via @property.
+ # We want to set those last, in case they access non-property members
+ #props = [p[0] for p in inspect.getmembers(self.__class__, lambda x : isinstance(x, property))]
+ props = _props_for(self.__class__)
+ kwarg_keys = set(kwargs.keys())
+ kwsecond = kwarg_keys.intersection(props)
+ kwfirst = kwarg_keys.difference(kwsecond)
+ kwall = list(kwfirst) + list(kwsecond)
+
+ # The complexity here comes because we wish to
+ # avoid clearing cache redundantly
+ if len(kwall) > 0:
+ for k in kwall[:-1]:
+ self.__setattr__(k, kwargs[k], 9999)
+ self.__setattr__(kwall[-1], kwargs[kwall[-1]], None)
+
+
+ def is_dr_wrt(self, wrt):
+ if type(self) == Ch:
+ return wrt is self
+ dterms_we_have = [getattr(self, dterm) for dterm in self.dterms if hasattr(self, dterm)]
+ return wrt in dterms_we_have or any([d.is_dr_wrt(wrt) for d in dterms_we_have])
+
+
+ def is_ch_baseclass(self):
+ return self.__class__ is Ch
+
+
+ ########################################################
+ # Getters for our outputs
+
+ def __getitem__(self, key):
+ shape = self.shape
+ tmp = np.arange(np.prod(shape)).reshape(shape).__getitem__(key)
+ idxs = tmp.ravel()
+ newshape = tmp.shape
+ return Select(a=self, idxs=idxs, preferred_shape=newshape)
+
+ def __setitem__(self, key, value, itr=None):
+
+ if hasattr(value, 'dterms'):
+ raise Exception("Can't assign a Ch objects as a subset of another.")
+ if type(self) == Ch:# self.is_ch_baseclass():
+ data = np.atleast_1d(self.x)
+ data.__setitem__(key, value)
+ self.__setattr__('x', data, itr=itr)
+ return
+ # elif False: # Interesting but flawed idea
+ # parents = [self.__dict__[k] for k in self.dterms]
+ # kids = []
+ # while len(parents)>0:
+ # p = parents.pop()
+ # if p.is_ch_baseclass():
+ # kids.append(p)
+ # else:
+ # parents += [p.__dict__[k] for k in p.dterms]
+ # from ch.optimization import minimize_dogleg
+ # minimize_dogleg(obj=self.__getitem__(key) - value, free_variables=kids, show_residuals=False)
+ else:
+ inner = self
+ while not inner.is_ch_baseclass():
+ if issubclass(inner.__class__, Permute):
+ inner = inner.a
+ else:
+ raise Exception("Can't set array that is function of arrays.")
+
+ self = self[key]
+ dr = self.dr_wrt(inner)
+ dr_rev = dr.T
+ #dr_rev = np.linalg.pinv(dr)
+ inner_shape = inner.shape
+
+ t1 = self._superdot(dr_rev, np.asarray(value).ravel())
+ t2 = self._superdot(dr_rev, self._superdot(dr, inner.x.ravel()))
+ if sp.issparse(t1): t1 = np.array(t1.todense())
+ if sp.issparse(t2): t2 = np.array(t2.todense())
+
+ inner.x = inner.x + t1.reshape(inner_shape) - t2.reshape(inner_shape)
+ #inner.x = inner.x + self._superdot(dr_rev, value.ravel()).reshape(inner_shape) - self._superdot(dr_rev, self._superdot(dr, inner.x.ravel())).reshape(inner_shape)
+
+
+ def __str__(self):
+ return str(self.r)
+
+ def __repr__(self):
+ return object.__repr__(self) + '\n' + str(self.r)
+
+ def __float__(self):
+ return self.r.__float__()
+
+ def __int__(self):
+ return self.r.__int__()
+
+ def on_changed(self, terms):
+ pass
+
+ @property
+ def T(self):
+ return transpose(self)
+
+ def transpose(self, *axes):
+ return transpose(self, *axes)
+
+ def squeeze(self, axis=None):
+ return squeeze(self, axis)
+
+ def mean(self, axis=None):
+ return mean(self, axis=axis)
+
+ def sum(self, axis=None):
+ return sum(self, axis=axis)
+
+ def _call_on_changed(self):
+
+ if hasattr(self, 'is_valid'):
+ validity, msg = self.is_valid()
+ assert validity, msg
+ if hasattr(self, '_status'):
+ self._status = 'new'
+
+ if len(self._dirty_vars) > 0:
+ self.on_changed(self._dirty_vars)
+ object.__setattr__(self, '_dirty_vars', set())
+
+ @property
+ def r(self):
+ self._call_on_changed()
+ if self._cache['r'] is None:
+ self._cache['r'] = np.asarray(np.atleast_1d(self.compute_r()), dtype=np.float64, order='C')
+ self._cache['rview'] = self._cache['r'].view()
+ self._cache['rview'].flags.writeable = False
+
+ return self._cache['rview']
+
+ def _superdot(self, lhs, rhs, profiler=None):
+
+ try:
+ if lhs is None:
+ return None
+ if rhs is None:
+ return None
+
+ if isinstance(lhs, np.ndarray) and lhs.size==1:
+ lhs = lhs.ravel()[0]
+
+ if isinstance(rhs, np.ndarray) and rhs.size==1:
+ rhs = rhs.ravel()[0]
+
+ if isinstance(lhs, numbers.Number) or isinstance(rhs, numbers.Number):
+ return lhs * rhs
+
+ if isinstance(rhs, LinearOperator):
+ return LinearOperator((lhs.shape[0], rhs.shape[1]), lambda x : lhs.dot(rhs.dot(x)))
+
+ if isinstance(lhs, LinearOperator):
+ if sp.issparse(rhs):
+ return LinearOperator((lhs.shape[0], rhs.shape[1]), lambda x : lhs.dot(rhs.dot(x)))
+ else:
+ # TODO: ?????????????
+ # return lhs.matmat(rhs)
+ return lhs.dot(rhs)
+
+ # TODO: Figure out how/whether to do this.
+ tm_maybe_sparse = timer()
+ lhs, rhs = convert_inputs_to_sparse_if_necessary(lhs, rhs)
+ if tm_maybe_sparse() > 0.1:
+ pif('convert_inputs_to_sparse_if_necessary in {}sec'.format(tm_maybe_sparse()))
+
+ if not sp.issparse(lhs) and sp.issparse(rhs):
+ return rhs.T.dot(lhs.T).T
+ return lhs.dot(rhs)
+ except Exception as e:
+ import sys, traceback
+ traceback.print_exc(file=sys.stdout)
+ if DEBUG:
+ import pdb; pdb.post_mortem()
+ else:
+ raise
+
+ def lmult_wrt(self, lhs, wrt):
+ if lhs is None:
+ return None
+
+ self._call_on_changed()
+
+ drs = []
+
+ direct_dr = self._compute_dr_wrt_sliced(wrt)
+
+ if direct_dr != None:
+ drs.append(self._superdot(lhs, direct_dr))
+
+ for k in set(self.dterms):
+ p = self.__dict__[k]
+
+ if hasattr(p, 'dterms') and p is not wrt and p.is_dr_wrt(wrt):
+ if not isinstance(p, Ch):
+ print('BROKEN!')
+ raise Exception('Broken Should be Ch object')
+
+ indirect_dr = p.lmult_wrt(self._superdot(lhs, self._compute_dr_wrt_sliced(p)), wrt)
+ if indirect_dr is not None:
+ drs.append(indirect_dr)
+
+ if len(drs)==0:
+ result = None
+
+ elif len(drs)==1:
+ result = drs[0]
+
+ else:
+ result = reduce(lambda x, y: x+y, drs)
+
+ return result
+
+
+ def compute_lop(self, wrt, lhs):
+ dr = self._compute_dr_wrt_sliced(wrt)
+ if dr is None: return None
+ return self._superdot(lhs, dr) if not isinstance(lhs, LinearOperator) else lhs.matmat(dr)
+
+
+ def lop(self, wrt, lhs):
+ self._call_on_changed()
+
+ drs = []
+ direct_dr = self.compute_lop(wrt, lhs)
+ if direct_dr is not None:
+ drs.append(direct_dr)
+
+ for k in set(self.dterms):
+ p = getattr(self, k) # self.__dict__[k]
+ if hasattr(p, 'dterms') and p is not wrt: # and p.is_dr_wrt(wrt):
+ lhs_for_child = self.compute_lop(p, lhs)
+ if lhs_for_child is not None: # Can be None with ChLambda, _result etc
+ indirect_dr = p.lop(wrt, lhs_for_child)
+ if indirect_dr is not None:
+ drs.append(indirect_dr)
+
+ for k in range(len(drs)):
+ if sp.issparse(drs[k]):
+ drs[k] = drs[k].todense()
+
+ if len(drs)==0:
+ result = None
+
+ elif len(drs)==1:
+ result = drs[0]
+
+ else:
+ result = reduce(lambda x, y: x+y, drs)
+
+
+ return result
+
+ def compute_rop(self, wrt, rhs):
+ dr = self._compute_dr_wrt_sliced(wrt)
+ if dr is None: return None
+
+ return self._superdot(dr, rhs)
+
+ def dr_wrt(self, wrt, reverse_mode=False, profiler=None):
+ tm_dr_wrt = timer()
+ self.called_dr_wrt = True
+ self._call_on_changed()
+
+ drs = []
+
+ if wrt in self._cache['drs']:
+ if DEBUG:
+ if wrt not in self._cache_info:
+ self._cache_info[wrt] = 0
+ self._cache_info[wrt] +=1
+ self._status = 'cached'
+ return self._cache['drs'][wrt]
+
+ direct_dr = self._compute_dr_wrt_sliced(wrt)
+
+ if direct_dr is not None:
+ drs.append(direct_dr)
+
+ if DEBUG:
+ self._status = 'pending'
+
+ propnames = set(_props_for(self.__class__))
+ for k in set(self.dterms).intersection(propnames.union(set(self.__dict__.keys()))):
+
+ p = getattr(self, k)
+
+ if hasattr(p, 'dterms') and p is not wrt:
+
+ indirect_dr = None
+
+ if reverse_mode:
+ lhs = self._compute_dr_wrt_sliced(p)
+ if isinstance(lhs, LinearOperator):
+ tm_dr_wrt.pause()
+ dr2 = p.dr_wrt(wrt)
+ tm_dr_wrt.resume()
+ indirect_dr = lhs.matmat(dr2) if dr2 != None else None
+ else:
+ indirect_dr = p.lmult_wrt(lhs, wrt)
+ else: # forward mode
+ tm_dr_wrt.pause()
+ dr2 = p.dr_wrt(wrt, profiler=profiler)
+ tm_dr_wrt.resume()
+ if dr2 is not None:
+ indirect_dr = self.compute_rop(p, rhs=dr2)
+
+ if indirect_dr is not None:
+ drs.append(indirect_dr)
+
+ if len(drs)==0:
+ result = None
+ elif len(drs)==1:
+ result = drs[0]
+ else:
+ # TODO: ????????
+ # result = np.sum(x for x in drs)
+ if not np.any([isinstance(a, LinearOperator) for a in drs]):
+ result = reduce(lambda x, y: x+y, drs)
+ else:
+ result = LinearOperator(drs[0].shape, lambda x : reduce(lambda a, b: a.dot(x)+b.dot(x),drs))
+
+ # TODO: figure out how/whether to do this.
+ if result is not None and not sp.issparse(result):
+ tm_nonzero = timer()
+ nonzero = np.count_nonzero(result)
+ if tm_nonzero() > 0.1:
+ pif('count_nonzero in {}sec'.format(tm_nonzero()))
+ if nonzero == 0 or hasattr(result, 'size') and result.size / float(nonzero) >= 10.0:
+ tm_convert_to_sparse = timer()
+ result = sp.csc_matrix(result)
+ import gc
+ gc.collect()
+ pif('converting result to sparse in {}sec'.format(tm_convert_to_sparse()))
+
+ if (result is not None) and (not sp.issparse(result)) and (not isinstance(result, LinearOperator)):
+ result = np.atleast_2d(result)
+
+ # When the number of parents is one, it indicates that
+ # caching this is probably not useful because not
+ # more than one parent will likely ask for this same
+ # thing again in the same iteration of an optimization.
+ #
+ # When the number of parents is zero, this is the top
+ # level object and should be cached; when it's > 1
+ # cache the combinations of the children.
+ #
+ # If we *always* filled in the cache, it would require
+ # more memory but would occasionally save a little cpu,
+ # on average.
+ if len(list(self._parents.keys())) != 1:
+ self._cache['drs'][wrt] = result
+
+ if DEBUG:
+ self._status = 'done'
+
+ if getattr(self, '_make_dense', False) and sp.issparse(result):
+ result = result.todense()
+ if getattr(self, '_make_sparse', False) and not sp.issparse(result):
+ result = sp.csc_matrix(result)
+
+ if tm_dr_wrt() > 0.1:
+ pif('dx of {} wrt {} in {}sec, sparse: {}'.format(self.short_name, wrt.short_name, tm_dr_wrt(), sp.issparse(result)))
+
+ return result
+
+
+ def __call__(self, **kwargs):
+ self.set(**kwargs)
+ return self.r
+
+
+ ########################################################
+ # Visualization
+
+ @property
+ def reset_flag(self):
+ """
+ Used as fn in loop_children_do
+ """
+ return lambda x: setattr(x, 'called_dr_wrt', False)
+
+ def loop_children_do(self, fn):
+ fn(self)
+ for dterm in self.dterms:
+ if hasattr(self, dterm):
+ dtval = getattr(self, dterm)
+ if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
+ if hasattr(dtval, 'loop_children_do'):
+ dtval.loop_children_do(fn)
+
+
+ def show_tree_cache(self, label, current_node=None):
+ '''
+ Show tree and cache info with color represent _status
+ Optionally accpet current_node arg to highlight the current node we are in
+ '''
+ import os
+ import tempfile
+ import subprocess
+
+ assert DEBUG, "Please use dr tree visualization functions in debug mode"
+
+ cache_path = os.path.abspath('profiles')
+ def string_for(self, my_name):
+
+ color_mapping = {'new' : 'grey', 'pending':'red', 'cached':'yellow', 'done': 'green'}
+ if hasattr(self, 'label'):
+ my_name = self.label
+ my_name = '%s (%s)' % (my_name, str(self.__class__.__name__))
+ result = []
+ if not hasattr(self, 'dterms'):
+ return result
+ for dterm in self.dterms:
+ if hasattr(self, dterm):
+ dtval = getattr(self, dterm)
+ if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
+ child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
+ child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
+ src = 'aaa%d' % (id(self))
+ dst = 'aaa%d' % (id(dtval))
+
+ s = ''
+ color = color_mapping[dtval._status] if hasattr(dtval, '_status') else 'grey'
+ if dtval == current_node:
+ color = 'blue'
+ if isinstance(dtval, Concatenate) and len(dtval.dr_cached) > 0:
+ s = 'dr_cached\n'
+ for k, v in dtval.dr_cached.items():
+ if v is not None:
+ issparse = sp.issparse(v)
+ size = v.size
+ if issparse:
+ size = v.shape[0] * v.shape[1]
+ nonzero = len(v.data)
+ else:
+ nonzero = np.count_nonzero(v)
+ s += '\nsparse: %s\nsize: %d\nnonzero: %d\n' % (issparse, size, nonzero)
+ # if dtval.called_dr_wrt:
+ # # dtval.called_dr_wrt = False
+ # color = 'brown3'
+ # else:
+ # color = 'azure1'
+ elif len(dtval._cache['drs']) > 0:
+ s = '_cache\n'
+
+ for k, v in dtval._cache['drs'].items():
+ if v is not None:
+ issparse = sp.issparse(v)
+ size = v.size
+ if issparse:
+ size = v.shape[0] * v.shape[1]
+ nonzero = len(v.data)
+ else:
+ nonzero = np.count_nonzero(v)
+
+ s += '\nsparse: %s\nsize: %d\nnonzero: %d\n' % (issparse, size, nonzero)
+ if hasattr(dtval, '_cache_info'):
+ s += '\ncache hit:%s\n' % dtval._cache_info[k]
+ # if hasattr(dtval,'called_dr_wrt') and dtval.called_dr_wrt:
+ # # dtval.called_dr_wrt = False
+ # color = 'brown3'
+ # else:
+ # color = 'azure1'
+ result += ['%s -> %s;' % (src, dst)]
+ # Do not overwrite src
+ #result += ['%s [label="%s"];' % (src, my_name)]
+ result += ['%s [label="%s\n%s\n", color=%s, style=filled];' %
+ (dst, child_label, s, color)]
+ result += string_for(getattr(self, dterm), dterm)
+ return result
+
+
+ dot_file_contents = 'digraph G {\n%s\n}' % '\n'.join(list(set(string_for(self, 'root'))))
+ dot_file_name = os.path.join(cache_path, label)
+ png_file_name = os.path.join(cache_path, label+'.png')
+ with open(dot_file_name, 'w') as dot_file:
+ with open(png_file_name, 'w') as png_file:
+ dot_file.write(dot_file_contents)
+ dot_file.flush()
+
+ png_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
+ subprocess.call(['dot', '-Tpng', '-o', png_file.name, dot_file.name])
+
+ import webbrowser
+ webbrowser.open('file://' + png_file.name)
+
+ self.loop_children_do(self.reset_flag)
+
+ def show_tree_wrt(self, wrt):
+ import tempfile
+ import subprocess
+
+ assert DEBUG, "Please use dr tree visualization functions in debug mode"
+
+ def string_for(self, my_name, wrt):
+ if hasattr(self, 'label'):
+ my_name = self.label
+ my_name = '%s (%s)' % (my_name, str(self.__class__.__name__))
+ result = []
+ if not hasattr(self, 'dterms'):
+ return result
+ for dterm in self.dterms:
+ if hasattr(self, dterm):
+ dtval = getattr(self, dterm)
+ if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
+ child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
+ child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
+ src = 'aaa%d' % (id(self))
+ dst = 'aaa%d' % (id(dtval))
+ result += ['%s -> %s;' % (src, dst)]
+ result += ['%s [label="%s"];' % (src, my_name)]
+ if wrt in dtval._cache['drs'] and dtval._cache['drs'][wrt] is not None:
+ issparse = sp.issparse(dtval._cache['drs'][wrt])
+ size = dtval._cache['drs'][wrt].size
+ nonzero = np.count_nonzero(dtval._cache['drs'][wrt])
+ result += ['%s [label="%s\n is_sparse: %s\nsize: %d\nnonzero: %d"];' %
+ (dst, child_label, issparse, size,
+ nonzero)]
+ else:
+ result += ['%s [label="%s"];' % (dst, child_label)]
+ result += string_for(getattr(self, dterm), dterm, wrt)
+ return result
+
+
+ dot_file_contents = 'digraph G {\n%s\n}' % '\n'.join(list(set(string_for(self, 'root', wrt))))
+ dot_file = tempfile.NamedTemporaryFile()
+ dot_file.write(dot_file_contents)
+ dot_file.flush()
+ png_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
+ subprocess.call(['dot', '-Tpng', '-o', png_file.name, dot_file.name])
+ import webbrowser
+ webbrowser.open('file://' + png_file.name)
+
+ def show_tree(self, cachelim=np.inf):
+ """Cachelim is in Mb. For any cached jacobians above cachelim, they are also added to the graph. """
+ import tempfile
+ import subprocess
+
+ assert DEBUG, "Please use dr tree visualization functions in debug mode"
+
+ def string_for(self, my_name):
+ if hasattr(self, 'label'):
+ my_name = self.label
+ my_name = '%s (%s)' % (my_name, str(self.__class__.__name__))
+ result = []
+ if not hasattr(self, 'dterms'):
+ return result
+ for dterm in self.dterms:
+ if hasattr(self, dterm):
+ dtval = getattr(self, dterm)
+ if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
+ child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
+ child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
+ src = 'aaa%d' % (id(self))
+ dst = 'aaa%d' % (id(dtval))
+ result += ['%s -> %s;' % (src, dst)]
+ result += ['%s [label="%s"];' % (src, my_name)]
+ result += ['%s [label="%s"];' % (dst, child_label)]
+ result += string_for(getattr(self, dterm), dterm)
+
+ if cachelim != np.inf and hasattr(self, '_cache') and 'drs' in self._cache:
+ from six.moves import cPickle as pickle
+ for dtval, jac in list(self._cache['drs'].items()):
+ # child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
+ # child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
+ src = 'aaa%d' % (id(self))
+ dst = 'aaa%d' % (id(dtval))
+ try:
+ sz = sys.getsizeof(pickle.dumps(jac, -1))
+ except: # some are functions
+ sz = 0
+ # colorattr = "#%02x%02x%02x" % (szpct*255, 0, (1-szpct)*255)
+ # print colorattr
+ if sz > (cachelim * 1024 * 1024):
+ result += ['%s -> %s [style=dotted,color="<<<%d>>>"];' % (src, dst, sz)]
+ #
+ # result += ['%s -> %s [style=dotted];' % (src, dst)]
+ # result += ['%s [label="%s"];' % (src, my_name)]
+ # result += ['%s [label="%s"];' % (dst, child_label)]
+ # result += string_for(getattr(self, dterm), dterm)
+
+ return result
+
+ dot_file_contents = 'digraph G {\n%s\n}' % '\n'.join(list(set(string_for(self, 'root'))))
+ if cachelim != np.inf:
+ import re
+ strs = re.findall(r'<<<(\d+)>>>', dot_file_contents, re.DOTALL)
+ if len(strs) > 0:
+ the_max = np.max(np.array([int(d) for d in strs]))
+ for s in strs:
+ szpct = float(s)/the_max
+ sz = float(s)
+ unit = 'b'
+ if sz > 1024.:
+ sz /= 1024
+ unit = 'K'
+ if sz > 1024.:
+ sz /= 1024
+ unit = 'M'
+ if sz > 1024.:
+ sz /= 1024
+ unit = 'G'
+ if sz > 1024.:
+ sz /= 1024
+ unit = 'T'
+
+ dot_file_contents = re.sub('<<<%s>>>' % s, '#%02x%02x%02x",label="%d%s' % (szpct*255, 0, (1-szpct)*255, sz, unit), dot_file_contents)
+
+ dot_file = tempfile.NamedTemporaryFile()
+ dot_file.write(dot_file_contents)
+ dot_file.flush()
+ png_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
+ subprocess.call(['dot', '-Tpng', '-o', png_file.name, dot_file.name])
+ import webbrowser
+ webbrowser.open('file://' + png_file.name)
+
+
+ def tree_iterator(self, visited=None, path=None):
+ '''
+ Generator function that traverse the dr tree start from this node (self).
+ '''
+ if visited is None:
+ visited = set()
+
+ if self not in visited:
+ if path and isinstance(path, list):
+ path.append(self)
+
+ visited.add(self)
+ yield self
+
+ if not hasattr(self, 'dterms'):
+ yield
+
+ for dterm in self.dterms:
+ if hasattr(self, dterm):
+ child = getattr(self, dterm)
+ if hasattr(child, 'dterms') or hasattr(child, 'terms'):
+ for node in child.tree_iterator(visited):
+ yield node
+
+ def floor(self):
+ return floor(self)
+
+ def ceil(self):
+ return ceil(self)
+
+ def dot(self, other):
+ return dot(self, other)
+
+ def cumsum(self, axis=None):
+ return cumsum(a=self, axis=axis)
+
+ def min(self, axis=None):
+ return amin(a=self, axis=axis)
+
+ def max(self, axis=None):
+ return amax(a=self, axis=axis)
+
+ ########################################################
+ # Operator overloads
+
+ def __pos__(self): return self
+ def __neg__(self): return negative(self)
+
+ def __add__ (self, other): return add(a=self, b=other)
+ def __radd__(self, other): return add(a=other, b=self)
+
+ def __sub__ (self, other): return subtract(a=self, b=other)
+ def __rsub__(self, other): return subtract(a=other, b=self)
+
+ def __mul__ (self, other): return multiply(a=self, b=other)
+ def __rmul__(self, other): return multiply(a=other, b=self)
+
+ def __div__ (self, other): return divide(x1=self, x2=other)
+ def __truediv__ (self, other): return divide(x1=self, x2=other)
+ def __rdiv__(self, other): return divide(x1=other, x2=self)
+
+ def __pow__ (self, other): return power(x=self, pow=other)
+ def __rpow__(self, other): return power(x=other, pow=self)
+
+ def __rand__(self, other): return self.__and__(other)
+
+ def __abs__ (self): return abs(self)
+
+ def __gt__(self, other): return greater(self, other)
+ def __ge__(self, other): return greater_equal(self, other)
+
+ def __lt__(self, other): return less(self, other)
+ def __le__(self, other): return less_equal(self, other)
+
+ def __ne__(self, other): return not_equal(self, other)
+
+ # not added yet because of weak key dict conflicts
+ #def __eq__(self, other): return equal(self, other)
+
+
+Ch._reserved_kw = set(Ch.__dict__.keys())
+
+
+class MatVecMult(Ch):
+ terms = 'mtx'
+ dterms = 'vec'
+ def compute_r(self):
+ result = self.mtx.dot(col(self.vec.r.ravel())).ravel()
+ if len(self.vec.r.shape) > 1 and self.vec.r.shape[1] > 1:
+ result = result.reshape((-1,self.vec.r.shape[1]))
+ return result
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.vec:
+ return sp.csc_matrix(self.mtx)
+
+
+#def depends_on(*dependencies):
+# def _depends_on(func):
+# @wraps(func)
+# def with_caching(self, *args, **kwargs):
+# return func(self, *args, **kwargs)
+# return property(with_caching)
+# return _depends_on
+
+
+def depends_on(*dependencies):
+ deps = set()
+ for dep in dependencies:
+ if isinstance(dep, str):
+ deps.add(dep)
+ else:
+ [deps.add(d) for d in dep]
+
+ def _depends_on(func):
+ want_out = 'out' in inspect.getargspec(func).args
+
+ @wraps(func)
+ def with_caching(self, *args, **kwargs):
+ func_name = func.__name__
+ sdf = self._depends_on_deps[func_name]
+ if sdf['out_of_date'] == True:
+ #tm = time.time()
+ if want_out:
+ kwargs['out'] = sdf['value']
+ sdf['value'] = func(self, *args, **kwargs)
+ sdf['out_of_date'] = False
+ #print 'recomputed %s in %.2e' % (func_name, time.time() - tm)
+ return sdf['value']
+ with_caching.deps = deps # set(dependencies)
+ result = property(with_caching)
+ return result
+ return _depends_on
+
+
+
+class ChHandle(Ch):
+ dterms = ('x',)
+
+ def compute_r(self):
+ assert(self.x is not self)
+ return self.x.r
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.x:
+ return 1
+
+
+class ChLambda(Ch):
+ terms = ['lmb', 'initial_args']
+ dterms = []
+ term_order = ['lmb', 'initial_args']
+
+ def on_changed(self, which):
+ for argname in set(which).intersection(set(self.args.keys())):
+ self.args[argname].x = getattr(self, argname)
+
+ def __init__(self, lmb, initial_args=None):
+ args = {argname: ChHandle(x=Ch(idx)) for idx, argname in enumerate(inspect.getargspec(lmb)[0])}
+ if initial_args is not None:
+ for initial_arg in initial_args:
+ if initial_arg in args:
+ args[initial_arg].x = initial_args[initial_arg]
+ result = lmb(**args)
+ for argname, arg in list(args.items()):
+ if result.is_dr_wrt(arg.x):
+ self.add_dterm(argname, arg.x)
+ else:
+ self.terms.append(argname)
+ setattr(self, argname, arg.x)
+ self.args = args
+ self.add_dterm('_result', result)
+
+ def __getstate__(self):
+ # Have to get rid of lambda for serialization
+ if hasattr(self, 'lmb'):
+ self.lmb = None
+ return super(self.__class__, self).__getstate__()
+
+
+ def compute_r(self):
+ return self._result.r
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self._result:
+ return 1
+
+# ChGroup is similar to ChLambda in that it's designed to expose the "internal"
+# inputs of result but, unlike ChLambda, result is kept internal and called when
+# compute_r and compute_dr_wrt is called to compute the relevant Jacobians.
+# This provides a way of effectively applying the chain rule in a different order.
+class ChGroup(Ch):
+ terms = ['result', 'args']
+ dterms = []
+ term_order = ['result', 'args']
+
+ def on_changed(self, which):
+ for argname in set(which).intersection(set(self.args.keys())):
+ if not self.args[argname].x is getattr(self, argname) :
+ self.args[argname].x = getattr(self, argname)
+
+ # right now the entries in args have to refer to terms/dterms of result,
+ # it would be better if they could be "internal" as well, but for now the idea
+ # is that result may itself be a ChLambda.
+ def __init__(self, result, args):
+ self.args = { argname: ChHandle(x=arg) for argname, arg in list(args.items()) }
+ for argname, arg in list(self.args.items()):
+ setattr(result, argname, arg)
+ if result.is_dr_wrt(arg.x):
+ self.add_dterm(argname, arg.x)
+ else:
+ self.terms.append(argname)
+ setattr(self, argname, arg.x)
+ self._result = result
+
+ def compute_r(self):
+ return self._result.r
+
+ def compute_dr_wrt(self, wrt):
+ return self._result.dr_wrt(wrt)
+
+from .ch_ops import *
+from .ch_ops import __all__ as all_ch_ops
+__all__ += all_ch_ops
+
+from .reordering import *
+from .reordering import Permute
+from .reordering import __all__ as all_reordering
+__all__ += all_reordering
+
+
+from . import linalg
+from . import ch_random as random
+__all__ += ['linalg', 'random']
+
+
+
+
+
+class tst(Ch):
+ dterms = ['a', 'b', 'c']
+
+ def compute_r(self):
+ return self.a.r + self.b.r + self.c.r
+
+ def compute_dr_wrt(self, wrt):
+ return 1
+
+def main():
+ foo = tst
+
+ x10 = Ch(10)
+ x20 = Ch(20)
+ x30 = Ch(30)
+
+ tmp = ChLambda(lambda x, y, z: Ch(1) + Ch(2) * Ch(3) + 4)
+ print(tmp.dr_wrt(tmp.x))
+ import pdb; pdb.set_trace()
+ #a(b(c(d(e(f),g),h)))
+
+ blah = tst(x10, x20, x30)
+
+ print(blah.r)
+
+
+ print(foo)
+
+ import pdb; pdb.set_trace()
+
+ # import unittest
+ # from test_ch import TestCh
+ # suite = unittest.TestLoader().loadTestsFromTestCase(TestCh)
+ # unittest.TextTestRunner(verbosity=2).run(suite)
+
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/chumpy-0.70/chumpy/ch_ops.py b/chumpy-0.70/chumpy/ch_ops.py
new file mode 100755
index 00000000..f5f2f744
--- /dev/null
+++ b/chumpy-0.70/chumpy/ch_ops.py
@@ -0,0 +1,814 @@
+#!/usr/bin/env python
+# encoding: utf-8
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+# Numpy functions
+__all__ = ['array', 'amax','amin', 'max', 'min', 'maximum','minimum','nanmax','nanmin',
+ 'sum', 'exp', 'log', 'mean','std', 'var',
+ 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
+ 'sqrt', 'square', 'absolute', 'abs', 'clip',
+ 'power',
+ 'add', 'divide', 'multiply', 'negative', 'subtract', 'reciprocal',
+ 'nan_to_num',
+ 'dot', 'cumsum',
+ 'floor', 'ceil',
+ 'greater', 'greater_equal', 'less', 'less_equal', 'equal', 'not_equal',
+ 'nonzero', 'ascontiguousarray', 'asfarray', 'arange', 'asarray', 'copy',
+ 'cross',
+ 'shape', 'sign']
+
+
+__all__ += ['SumOfSquares',
+ 'NanDivide', ]
+
+
+# These can be wrapped directly as Ch(routine(*args, **kwargs)),
+# so that for example "ch.eye(3)" translates into Ch(np.eye(3))
+numpy_array_creation_routines = [
+ 'empty','empty_like','eye','identity','ones','ones_like','zeros','zeros_like',
+ 'array',
+ 'arange','linspace','logspace','meshgrid','mgrid','ogrid',
+ 'fromfunction', 'fromiter', 'meshgrid', 'tri'
+]
+
+
+wont_implement = ['asanyarray', 'asmatrix', 'frombuffer', 'copy', 'fromfile', 'fromstring', 'loadtxt', 'copyto', 'asmatrix', 'asfortranarray', 'asscalar', 'require']
+not_yet_implemented = ['tril', 'triu', 'vander']
+
+__all__ += not_yet_implemented
+__all__ += wont_implement
+__all__ += numpy_array_creation_routines
+
+
+from .ch import Ch
+import six
+import numpy as np
+import warnings
+from six.moves import cPickle as pickle
+import scipy.sparse as sp
+from .utils import row, col
+from copy import copy as copy_copy
+from functools import reduce
+
+__all__ += ['pi', 'set_printoptions']
+pi = np.pi
+set_printoptions = np.set_printoptions
+arange = np.arange
+
+for rtn in ['argmax', 'nanargmax', 'argmin', 'nanargmin']:
+ exec('def %s(a, axis=None) : return np.%s(a.r, axis) if hasattr(a, "compute_r") else np.%s(a, axis)' % (rtn, rtn, rtn))
+ __all__ += [rtn]
+
+for rtn in ['argwhere', 'nonzero', 'flatnonzero']:
+ exec('def %s(a) : return np.%s(a.r) if hasattr(a, "compute_r") else np.%s(a)' % (rtn, rtn, rtn))
+ __all__ += [rtn]
+
+for rtn in numpy_array_creation_routines:
+ exec('def %s(*args, **kwargs) : return Ch(np.%s(*args, **kwargs))' % (rtn, rtn))
+
+
+class WontImplement(Exception):
+ pass
+
+for rtn in wont_implement:
+ exec('def %s(*args, **kwargs) : raise WontImplement' % (rtn))
+
+for rtn in not_yet_implemented:
+ exec('def %s(*args, **kwargs) : raise NotImplementedError' % (rtn))
+
+def asarray(a, dtype=None, order=None):
+ assert(dtype is None or dtype is np.float64)
+ assert(order is 'C' or order is None)
+ if hasattr(a, 'dterms'):
+ return a
+ return Ch(np.asarray(a, dtype, order))
+
+# Everythign is always c-contiguous
+def ascontiguousarray(a, dtype=None): return a
+
+# Everything is always float
+asfarray = ascontiguousarray
+
+def copy(self):
+ return pickle.loads(pickle.dumps(self))
+
+def asfortranarray(a, dtype=None): raise WontImplement
+
+
+class Simpleton(Ch):
+ dterms = 'x'
+ def compute_dr_wrt(self, wrt):
+ return None
+
+class floor(Simpleton):
+ def compute_r(self): return np.floor(self.x.r)
+
+class ceil(Simpleton):
+ def compute_r(self): return np.ceil(self.x.r)
+
+class sign(Simpleton):
+ def compute_r(self): return np.sign(self.x.r)
+
+class Cross(Ch):
+ dterms = 'a', 'b'
+ terms = 'axisa', 'axisb', 'axisc', 'axis'
+ term_order = 'a', 'b', 'axisa', 'axisb', 'axisc', 'axis'
+
+ def compute_r(self):
+ return np.cross(self.a.r, self.b.r, self.axisa, self.axisb, self.axisc, self.axis)
+
+
+ def _load_crossprod_cache(self, h, w):
+ if not hasattr(self, '_w'):
+ self._w = 0
+ self._h = 0
+
+ if h!=self._h or w!=self._w:
+ sz = h*w
+ rng = np.arange(sz)
+ self._JS = np.repeat(rng.reshape((-1,w)), w, axis=0).ravel()
+ self._IS = np.repeat(rng, w)
+ self._tiled_identity = np.tile(np.eye(w), (h, 1))
+ self._h = h
+ self._w = w
+
+ return self._tiled_identity, self._IS, self._JS,
+
+
+
+ # Could be at least 2x faster, with some work
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.a and wrt is not self.b:
+ return
+
+ sz = self.a.size
+ h, w = self.a.shape
+ tiled_identity, IS, JS = self._load_crossprod_cache(h, w)
+
+ #import time
+ #tm = time.time()
+ if wrt is self.a:
+ rp = np.repeat(-self.b.r, w, axis=0)
+ result = np.cross(
+ tiled_identity,
+ rp,
+ self.axisa,
+ self.axisb,
+ self.axisc,
+ self.axis)
+
+ elif wrt is self.b:
+ result = np.cross(
+ np.repeat(-self.a.r, w, axis=0),
+ tiled_identity,
+ self.axisa,
+ self.axisb,
+ self.axisc,
+ self.axis)
+
+ # rng = np.arange(sz)
+ # JS = np.repeat(rng.reshape((-1,w)), w, axis=0).ravel()
+ # IS = np.repeat(rng, w)
+ data = result.ravel()
+ result = sp.csc_matrix((data, (IS,JS)), shape=(self.size, wrt.size))
+ #import pdb; pdb.set_trace()
+ #print 'B TOOK %es' % (time.time() -tm )
+ return result
+
+def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
+ return Cross(a, b, axisa, axisb, axisc, axis)
+
+
+
+
+class cumsum(Ch):
+ dterms = 'a'
+ terms = 'axis'
+ term_order = 'a', 'axis'
+
+ def on_changed(self, which):
+ if not hasattr(self, 'axis'):
+ self.axis = None
+
+ def compute_r(self):
+ return np.cumsum(self.a.r, axis=self.axis)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.a:
+ return None
+
+ if self.axis is not None:
+ raise NotImplementedError
+
+ IS = np.tile(row(np.arange(self.a.size)), (self.a.size, 1))
+ JS = IS.T
+ IS = IS.ravel()
+ JS = JS.ravel()
+ which = IS >= JS
+ IS = IS[which]
+ JS = JS[which]
+ data = np.ones_like(IS)
+ result = sp.csc_matrix((data, (IS, JS)), shape=(self.a.size, self.a.size))
+ return result
+
+
+class UnaryElemwise(Ch):
+ dterms = 'x'
+
+ def compute_r(self):
+ return self._r(self.x.r)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.x:
+ result = self._d(self.x.r)
+ return sp.diags([result.ravel()], [0]) if len(result)>1 else np.atleast_2d(result)
+
+
+class nan_to_num(UnaryElemwise):
+ _r = lambda self, x : np.nan_to_num(x)
+ _d = lambda self, x : np.asarray(np.isfinite(x), np.float64)
+
+class reciprocal(UnaryElemwise):
+ _r = np.reciprocal
+ _d = lambda self, x : -np.reciprocal(np.square(x))
+
+class square(UnaryElemwise):
+ _r = np.square
+ _d = lambda self, x : x * 2.
+
+def my_power(a, b):
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore",category=RuntimeWarning)
+ return np.nan_to_num(np.power(a, b))
+
+class sqrt(UnaryElemwise):
+ _r = np.sqrt
+ _d = lambda self, x : .5 * my_power(x, -0.5)
+
+class exp(UnaryElemwise):
+ _r = np.exp
+ _d = np.exp
+
+class log(UnaryElemwise):
+ _r = np.log
+ _d = np.reciprocal
+
+class sin(UnaryElemwise):
+ _r = np.sin
+ _d = np.cos
+
+class arcsin(UnaryElemwise):
+ _r = np.arcsin
+ _d = lambda self, x : np.reciprocal(np.sqrt(1.-np.square(x)))
+
+class cos(UnaryElemwise):
+ _r = np.cos
+ _d = lambda self, x : -np.sin(x)
+
+class arccos(UnaryElemwise):
+ _r = np.arccos
+ _d = lambda self, x : -np.reciprocal(np.sqrt(1.-np.square(x)))
+
+class tan(UnaryElemwise):
+ _r = np.tan
+ _d = lambda self, x : np.reciprocal(np.cos(x)**2.)
+
+class arctan(UnaryElemwise):
+ _r = np.arctan
+ _d = lambda self, x : np.reciprocal(np.square(x)+1.)
+
+class negative(UnaryElemwise):
+ _r = np.negative
+ _d = lambda self, x : np.negative(np.ones_like(x))
+
+class absolute(UnaryElemwise):
+ _r = np.abs
+ _d = lambda self, x : (x>0)*2-1.
+
+abs = absolute
+
+class clip(Ch):
+ dterms = 'a'
+ terms = 'a_min', 'a_max'
+ term_order = 'a', 'a_min', 'a_max'
+
+ def compute_r(self):
+ return np.clip(self.a.r, self.a_min, self.a_max)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.a:
+ result = np.asarray((self.r != self.a_min) & (self.r != self.a_max), np.float64)
+ return sp.diags([result.ravel()], [0]) if len(result)>1 else np.atleast_2d(result)
+
+class sum(Ch):
+ dterms = 'x',
+ terms = 'axis',
+ term_order = 'x', 'axis'
+
+ def on_changed(self, which):
+ if not hasattr(self, 'axis'):
+ self.axis = None
+ if not hasattr(self, 'dr_cache'):
+ self.dr_cache = {}
+
+ def compute_r(self):
+ return np.sum(self.x.r, axis=self.axis)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.x:
+ return
+ if self.axis == None:
+ return row(np.ones((1, len(self.x.r.ravel()))))
+ else:
+ uid = tuple(list(self.x.shape) + [self.axis])
+ if uid not in self.dr_cache:
+ idxs_presum = np.arange(self.x.size).reshape(self.x.shape)
+ idxs_presum = np.rollaxis(idxs_presum, self.axis, 0)
+ idxs_postsum = np.arange(self.r.size).reshape(self.r.shape)
+ tp = np.ones(idxs_presum.ndim, dtype=np.uint32)
+ tp[0] = idxs_presum.shape[0]
+ idxs_postsum = np.tile(idxs_postsum, tp)
+ data = np.ones(idxs_postsum.size)
+ result = sp.csc_matrix((data, (idxs_postsum.ravel(), idxs_presum.ravel())), (self.r.size, wrt.size))
+ self.dr_cache[uid] = result
+ return self.dr_cache[uid]
+
+
+class mean(Ch):
+ dterms = 'x',
+ terms = 'axis',
+ term_order = 'x', 'axis'
+
+ def on_changed(self, which):
+ if not hasattr(self, 'axis'):
+ self.axis = None
+ if not hasattr(self, 'dr_cache'):
+ self.dr_cache = {}
+
+ def compute_r(self):
+ return np.array(np.mean(self.x.r, axis=self.axis))
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.x:
+ return
+ if self.axis == None:
+ return row(np.ones((1, len(self.x.r))))/len(self.x.r)
+ else:
+ uid = tuple(list(self.x.shape) + [self.axis])
+ if uid not in self.dr_cache:
+ idxs_presum = np.arange(self.x.size).reshape(self.x.shape)
+ idxs_presum = np.rollaxis(idxs_presum, self.axis, 0)
+ idxs_postsum = np.arange(self.r.size).reshape(self.r.shape)
+ tp = np.ones(idxs_presum.ndim, dtype=np.uint32)
+ tp[0] = idxs_presum.shape[0]
+ idxs_postsum = np.tile(idxs_postsum, tp)
+ data = np.ones(idxs_postsum.size) / self.x.shape[self.axis]
+ result = sp.csc_matrix((data, (idxs_postsum.ravel(), idxs_presum.ravel())), (self.r.size, wrt.size))
+ self.dr_cache[uid] = result
+ return self.dr_cache[uid]
+
+
+def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
+ if (dtype != None or out != None or ddof != 0 or keepdims != False):
+ raise NotImplementedException('Unimplemented for non-default dtype, out, ddof, and keepdims.')
+ return mean(a**2., axis=axis)
+
+def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
+ if (dtype != None or out != None or ddof != 0 or keepdims != False):
+ raise NotImplementedException('Unimplemented for non-default dtype, out, ddof, and keepdims.')
+ return sqrt(var(a, axis=axis))
+
+
+class SumOfSquares(Ch):
+ dterms = 'x',
+
+ def compute_r(self):
+ return np.sum(self.x.r.ravel()**2.)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.x:
+ return row(self.x.r.ravel()*2.)
+
+
+class divide (Ch):
+ dterms = 'x1', 'x2'
+
+ def compute_r(self):
+ return self.x1.r / self.x2.r
+
+ def compute_dr_wrt(self, wrt):
+
+ if (wrt is self.x1) == (wrt is self.x2):
+ return None
+
+ IS, JS, input_sz, output_sz = _broadcast_setup(self.x1, self.x2, wrt)
+
+ x1r, x2r = self.x1.r, self.x2.r
+ if wrt is self.x1:
+ data = (np.ones_like(x1r) / x2r).ravel()
+ else:
+ data = (-x1r / (x2r*x2r)).ravel()
+
+ return sp.csc_matrix((data, (IS, JS)), shape=(self.r.size, wrt.r.size))
+
+
+
+
+class NanDivide(divide):
+ dterms = 'x1', 'x2'
+
+ def compute_r(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ result = super(self.__class__, self).compute_r()
+ shape = result.shape
+ result = result.ravel()
+ result[np.isinf(result)] = 0
+ result[np.isnan(result)] = 0
+ return result.reshape(shape)
+
+ def compute_dr_wrt(self, wrt):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ result = super(self.__class__, self).compute_dr_wrt(wrt)
+ if result is not None:
+ result = result.copy()
+ if sp.issparse(result):
+ result.data[np.isinf(result.data)] = 0
+ result.data[np.isnan(result.data)] = 0
+ return result
+ else:
+ rr = result.ravel()
+ rr[np.isnan(rr)] = 0.
+ rr[np.isinf(rr)] = 0.
+ return result
+
+
+def shape(a):
+ return a.shape if hasattr(a, 'shape') else np.shape(a)
+
+
+_bs_setup_data1 = {}
+_bs_setup_data2 = {}
+def _broadcast_matrix(a, b, wrt, data):
+ global _bs_setup_data1, _bs_setup_data2
+
+ if len(set((a.shape, b.shape))) == 1:
+ uid = a.shape
+ if uid not in _bs_setup_data1:
+ asz = a.size
+ IS = np.arange(asz)
+ _bs_setup_data1[uid] = sp.csc_matrix((np.empty(asz), (IS, IS)), shape=(asz, asz))
+ result = copy_copy(_bs_setup_data1[uid])
+ if isinstance(data, np.ndarray):
+ result.data = data.ravel()
+ else: # assumed scalar
+ result.data = np.empty(result.nnz)
+ result.data.fill(data)
+ else:
+ uid = (a.shape, b.shape, wrt is a, wrt is b)
+ if uid not in _bs_setup_data2:
+ input_sz = wrt.size
+ output_sz = np.broadcast(a.r, b.r).size
+ a2 = np.arange(a.size).reshape(a.shape) if wrt is a else np.zeros(a.shape)
+ b2 = np.arange(b.size).reshape(b.shape) if (wrt is b and wrt is not a) else np.zeros(b.shape)
+ IS = np.arange(output_sz)
+ JS = np.asarray((np.add(a2,b2)).ravel(), np.uint32)
+
+ _bs_setup_data2[uid] = sp.csc_matrix((np.arange(IS.size), (IS, JS)), shape=(output_sz, input_sz))
+
+ result = copy_copy(_bs_setup_data2[uid])
+ if isinstance(data, np.ndarray):
+ result.data = data[result.data]
+ else: # assumed scalar
+ result.data = np.empty(result.nnz)
+ result.data.fill(data)
+
+ if np.prod(result.shape) == 1:
+ return np.array(data)
+ else:
+ return result
+
+
+
+
+broadcast_shape_cache = {}
+def broadcast_shape(a_shape, b_shape):
+ global broadcast_shape_cache
+
+ raise Exception('This function is probably a bad idea, because shape is not cached and overquerying can occur.')
+
+ uid = (a_shape, b_shape)
+
+ if uid not in broadcast_shape_cache:
+ la = len(a_shape)
+ lb = len(b_shape)
+ ln = la if la > lb else lb
+
+ ash = np.ones(ln, dtype=np.uint32)
+ bsh = np.ones(ln, dtype=np.uint32)
+ ash[-la:] = a_shape
+ bsh[-lb:] = b_shape
+
+ our_result = np.max(np.vstack((ash, bsh)), axis=0)
+
+ if False:
+ numpy_result = np.broadcast(np.empty(a_shape), np.empty(b_shape)).shape
+ #print 'aaa' + str(our_result)
+ #print 'bbb' + str(numpy_result)
+ if not np.array_equal(our_result, numpy_result):
+ raise Exception('numpy result not equal to our result')
+ assert(np.array_equal(our_result, numpy_result))
+
+ broadcast_shape_cache[uid] = tuple(our_result)
+ return broadcast_shape_cache[uid]
+
+
+def _broadcast_setup(a, b, wrt):
+ if len(set((a.shape, b.shape))) == 1:
+ asz = a.size
+ IS = np.arange(asz)
+ return IS, IS, asz, asz
+ input_sz = wrt.r.size
+ output_sz = np.broadcast(a.r, b.r).size
+ a2 = np.arange(a.size).reshape(a.shape) if wrt is a else np.zeros(a.shape)
+ b2 = np.arange(b.size).reshape(b.shape) if (wrt is b and wrt is not a) else np.zeros(b.shape)
+ IS = np.arange(output_sz)
+ JS = np.asarray((np.add(a2,b2)).ravel(), np.uint32)
+ return IS, JS, input_sz, output_sz
+
+
+
+class add(Ch):
+ dterms = 'a', 'b'
+
+ def compute_r(self):
+ return self.a.r + self.b.r
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.a and wrt is not self.b:
+ return None
+
+ m = 2. if self.a is self.b else 1.
+ return _broadcast_matrix(self.a, self.b, wrt, m)
+
+
+
+
+class subtract(Ch):
+ dterms = 'a', 'b'
+
+ def compute_r(self):
+ return self.a.r - self.b.r
+
+ def compute_dr_wrt(self, wrt):
+ if (wrt is self.a) == (wrt is self.b):
+ return None
+
+ m = 1. if wrt is self.a else -1.
+ return _broadcast_matrix(self.a, self.b, wrt, m)
+
+
+
+
+
+class power (Ch):
+ """Given vector \f$x\f$, computes \f$x^2\f$ and \f$\frac{dx^2}{x}\f$"""
+ dterms = 'x', 'pow'
+
+ def compute_r(self):
+ return self.safe_power(self.x.r, self.pow.r)
+
+ def compute_dr_wrt(self, wrt):
+
+ if wrt is not self.x and wrt is not self.pow:
+ return None
+
+ x, pow = self.x.r, self.pow.r
+ result = []
+ if wrt is self.x:
+ result.append(pow * self.safe_power(x, pow-1.))
+ if wrt is self.pow:
+ result.append(np.log(x) * self.safe_power(x, pow))
+
+ data = reduce(lambda x, y : x + y, result).ravel()
+
+ return _broadcast_matrix(self.x, self.pow, wrt, data)
+
+
+ def safe_power(self, x, sigma):
+ # This throws a RuntimeWarning sometimes, but then the infs are corrected below
+ result = np.power(x, sigma)
+ result.ravel()[np.isinf(result.ravel())] = 0
+ return result
+
+
+
+
+
+class A_extremum(Ch):
+ """Superclass for various min and max subclasses"""
+ dterms = 'a'
+ terms = 'axis'
+ term_order = 'a', 'axis'
+
+ def f(self, axis): raise NotImplementedError
+ def argf(self, axis): raise NotImplementedError
+
+ def on_changed(self, which):
+ if not hasattr(self, 'axis'):
+ self.axis = None
+
+ def compute_r(self):
+ return self.f(self.a.r, axis=self.axis)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.a:
+
+ mn, stride = self._stride_for_axis(self.axis, self.a.r)
+ JS = np.asarray(np.round(mn + stride * self.argf(self.a.r, axis=self.axis)), dtype=np.uint32).ravel()
+ IS = np.arange(JS.size)
+ data = np.ones(JS.size)
+
+ if self.r.size * wrt.r.size == 1:
+ return data.ravel()[0]
+ return sp.csc_matrix((data, (IS, JS)), shape = (self.r.size, wrt.r.size))
+
+ def _stride_for_axis(self,axis, mtx):
+ if axis is None:
+ mn = np.array([0])
+ stride = np.array([1])
+ else:
+ # TODO: make this less expensive. Shouldn't need to call
+ # np.amin here probably
+ idxs = np.arange(mtx.size).reshape(mtx.shape)
+ mn = np.amin(idxs, axis=axis)
+ mtx_strides = np.array(mtx.strides)
+ stride = mtx_strides / np.min(mtx_strides) # go from bytes to num elements
+ stride = stride[axis]
+ return mn, stride
+
+
+class amax(A_extremum):
+ def f(self, *args, **kwargs): return np.amax(*args, **kwargs)
+ def argf(self, *args, **kwargs): return np.argmax(*args, **kwargs)
+
+max = amax
+
+class amin(A_extremum):
+ def f(self, *args, **kwargs): return np.amin(*args, **kwargs)
+ def argf(self, *args, **kwargs): return np.argmin(*args, **kwargs)
+
+min = amin
+
+class nanmin(A_extremum):
+ def f(self, *args, **kwargs): return np.nanmin(*args, **kwargs)
+ def argf(self, *args, **kwargs): return np.nanargmin(*args, **kwargs)
+
+class nanmax(A_extremum):
+ def f(self, *args, **kwargs): return np.nanmax(*args, **kwargs)
+ def argf(self, *args, **kwargs): return np.nanargmax(*args, **kwargs)
+
+
+class Extremum(Ch):
+ dterms = 'a','b'
+
+ def compute_r(self): return self.f(self.a.r, self.b.r)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.a and wrt is not self.b:
+ return None
+
+ IS, JS, input_sz, output_sz = _broadcast_setup(self.a, self.b, wrt)
+ if wrt is self.a:
+ whichmax = (self.r == self.f(self.a.r, self.b.r-self.f(1,-1))).ravel()
+ else:
+ whichmax = (self.r == self.f(self.b.r, self.a.r-self.f(1,-1))).ravel()
+ IS = IS[whichmax]
+ JS = JS[whichmax]
+ data = np.ones(JS.size)
+
+ return sp.csc_matrix((data, (IS, JS)), shape=(self.r.size, wrt.r.size))
+
+class maximum(Extremum):
+ def f(self, a, b): return np.maximum(a, b)
+
+class minimum(Extremum):
+ def f(self, a, b): return np.minimum(a, b)
+
+
+class multiply(Ch):
+ dterms = 'a', 'b'
+
+ def compute_r(self):
+ return self.a.r * self.b.r
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.a and wrt is not self.b:
+ return None
+
+ a2 = self.a.r if wrt is self.b else np.ones(self.a.shape)
+ b2 = self.b.r if (wrt is self.a and wrt is not self.b) else np.ones(self.b.shape)
+ data = (a2 * b2).ravel()
+
+ if self.a is self.b:
+ data *= 2.
+
+ return _broadcast_matrix(self.a, self.b, wrt, data)
+
+
+
+
+
+class dot(Ch):
+ dterms = 'a', 'b'
+
+ def compute_r(self):
+ return self.a.r.dot(self.b.r)
+
+ def compute_d1(self):
+ # To stay consistent with numpy, we must upgrade 1D arrays to 2D
+ ar = row(self.a.r) if len(self.a.r.shape)<2 else self.a.r.reshape((-1, self.a.r.shape[-1]))
+ br = col(self.b.r) if len(self.b.r.shape)<2 else self.b.r.reshape((self.b.r.shape[0], -1))
+
+ if ar.ndim <= 2:
+ return sp.kron(sp.eye(ar.shape[0], ar.shape[0]),br.T)
+ else:
+ raise NotImplementedError
+
+ def compute_d2(self):
+
+ # To stay consistent with numpy, we must upgrade 1D arrays to 2D
+ ar = row(self.a.r) if len(self.a.r.shape)<2 else self.a.r.reshape((-1, self.a.r.shape[-1]))
+ br = col(self.b.r) if len(self.b.r.shape)<2 else self.b.r.reshape((self.b.r.shape[0], -1))
+
+ if br.ndim <= 1:
+ return self.ar
+ elif br.ndim <= 2:
+ return sp.kron(ar, sp.eye(br.shape[1],br.shape[1]))
+ else:
+ raise NotImplementedError
+
+
+ def compute_dr_wrt(self, wrt):
+
+ if wrt is self.a and wrt is self.b:
+ return self.compute_d1() + self.compute_d2()
+ elif wrt is self.a:
+ return self.compute_d1()
+ elif wrt is self.b:
+ return self.compute_d2()
+
+class BinaryElemwiseNoDrv(Ch):
+ dterms = 'x1', 'x2'
+
+ def compute_r(self):
+ return self._f(self.x1.r, self.x2.r)
+
+ def compute_dr_wrt(self, wrt):
+ return None
+
+class greater(BinaryElemwiseNoDrv):
+ def _f(self, a, b): return np.greater(a,b)
+
+class greater_equal(BinaryElemwiseNoDrv):
+ def _f(self, a, b): return np.greater_equal(a,b)
+
+class less(BinaryElemwiseNoDrv):
+ def _f(self, a, b): return np.less(a,b)
+
+class less_equal(BinaryElemwiseNoDrv):
+ def _f(self, a, b): return np.less_equal(a,b)
+
+class equal(BinaryElemwiseNoDrv):
+ def _f(self, a, b): return np.equal(a,b)
+
+class not_equal(BinaryElemwiseNoDrv):
+ def _f(self, a, b): return np.not_equal(a,b)
+
+def nonzero(a):
+ if hasattr(a, 'compute_r'):
+ a = a.r
+ return np.nonzero(a)
+
+# Pull the code for tensordot in from numpy and reinterpret it using chumpy ops
+import os
+source_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'np_tensordot.py')
+with open(source_path, 'r') as f:
+ source_lines = f.readlines()
+exec(''.join(source_lines))
+__all__ += ['tensordot']
+
+
+
+def main():
+ pass
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/chumpy-0.70/chumpy/ch_random.py b/chumpy-0.70/chumpy/ch_random.py
new file mode 100644
index 00000000..ccf10f2f
--- /dev/null
+++ b/chumpy-0.70/chumpy/ch_random.py
@@ -0,0 +1,32 @@
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+import numpy.random
+from .ch import Ch
+
+api_not_implemented = ['choice','bytes','shuffle','permutation']
+
+api_wrapped_simple = [
+ # simple random data
+ 'rand','randn','randint','random_integers','random_sample','random','ranf','sample',
+
+ # distributions
+ 'beta','binomial','chisquare','dirichlet','exponential','f','gamma','geometric','gumbel','hypergeometric',
+ 'laplace','logistic','lognormal','logseries','multinomial','multivariate_normal','negative_binomial',
+ 'noncentral_chisquare','noncentral_f','normal','pareto','poisson','power','rayleigh','standard_cauchy',
+ 'standard_exponential','standard_gamma','standard_normal','standard_t','triangular','uniform','vonmises',
+ 'wald','weibull','zipf']
+
+api_wrapped_direct = ['seed', 'get_state', 'set_state']
+
+for rtn in api_wrapped_simple:
+ exec('def %s(*args, **kwargs) : return Ch(numpy.random.%s(*args, **kwargs))' % (rtn, rtn))
+
+for rtn in api_wrapped_direct:
+ exec('%s = numpy.random.%s' % (rtn, rtn))
+
+__all__ = api_wrapped_simple + api_wrapped_direct
+
diff --git a/chumpy-0.70/chumpy/extras.py b/chumpy-0.70/chumpy/extras.py
new file mode 100644
index 00000000..7043f577
--- /dev/null
+++ b/chumpy-0.70/chumpy/extras.py
@@ -0,0 +1,72 @@
+__author__ = 'matt'
+
+from . import ch
+import numpy as np
+from .utils import row, col
+import scipy.sparse as sp
+import scipy.special
+
+class Interp3D(ch.Ch):
+ dterms = 'locations'
+ terms = 'image'
+
+ def on_changed(self, which):
+ if 'image' in which:
+ self.gx, self.gy, self.gz = np.gradient(self.image)
+
+ def compute_r(self):
+ locations = self.locations.r.copy()
+ for i in range(3):
+ locations[:,i] = np.clip(locations[:,i], 0, self.image.shape[i]-1)
+ locs = np.floor(locations).astype(np.uint32)
+ result = self.image[locs[:,0], locs[:,1], locs[:,2]]
+ offset = (locations - locs)
+ dr = self.dr_wrt(self.locations).dot(offset.ravel())
+ return result + dr
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.locations:
+ locations = self.locations.r.copy()
+ for i in range(3):
+ locations[:,i] = np.clip(locations[:,i], 0, self.image.shape[i]-1)
+ locations = locations.astype(np.uint32)
+
+ xc = col(self.gx[locations[:,0], locations[:,1], locations[:,2]])
+ yc = col(self.gy[locations[:,0], locations[:,1], locations[:,2]])
+ zc = col(self.gz[locations[:,0], locations[:,1], locations[:,2]])
+
+ data = np.vstack([xc.ravel(), yc.ravel(), zc.ravel()]).T.copy()
+ JS = np.arange(locations.size)
+ IS = JS // 3
+
+ return sp.csc_matrix((data.ravel(), (IS, JS)))
+
+
+class gamma(ch.Ch):
+ dterms = 'x',
+
+ def compute_r(self):
+ return scipy.special.gamma(self.x.r)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.x:
+ d = scipy.special.polygamma(0, self.x.r)*self.r
+ return sp.diags([d.ravel()], [0])
+
+# This function is based directly on the "moment" function
+# in scipy, specifically in mstats_basic.py.
+def moment(a, moment=1, axis=0):
+ if moment == 1:
+ # By definition the first moment about the mean is 0.
+ shape = list(a.shape)
+ del shape[axis]
+ if shape:
+ # return an actual array of the appropriate shape
+ return ch.zeros(shape, dtype=float)
+ else:
+ # the input was 1D, so return a scalar instead of a rank-0 array
+ return np.float64(0.0)
+ else:
+ mn = ch.expand_dims(a.mean(axis=axis), axis)
+ s = ch.power((a-mn), moment)
+ return s.mean(axis=axis)
diff --git a/chumpy-0.70/chumpy/linalg.py b/chumpy-0.70/chumpy/linalg.py
new file mode 100755
index 00000000..9838ebd2
--- /dev/null
+++ b/chumpy-0.70/chumpy/linalg.py
@@ -0,0 +1,306 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+
+__all__ = ['inv', 'svd', 'det', 'slogdet', 'pinv', 'lstsq', 'norm']
+
+import numpy as np
+import scipy.sparse as sp
+from .ch import Ch, depends_on
+from .ch_ops import NanDivide
+from .ch_ops import asarray as ch_asarray
+from .ch_ops import sqrt as ch_sqrt
+from .ch_ops import sum as ch_sum
+from .reordering import concatenate as ch_concatenate
+from .ch_random import randn as ch_random_randn
+from .utils import row, col
+
+
+try:
+ asarray = ch_asarray
+ import inspect
+ exec(''.join(inspect.getsourcelines(np.linalg.tensorinv)[0]))
+ __all__.append('tensorinv')
+except: pass
+
+def norm(x, ord=None, axis=None):
+ if ord is not None or axis is not None:
+ raise NotImplementedError("'ord' and 'axis' should be None for now.")
+
+ return ch_sqrt(ch_sum(x**2))
+
+# This version works but derivatives are too slow b/c of nested loop in Svd implementation.
+# def lstsq(a, b):
+# u, s, v = Svd(a)
+# x = (v.T / s).dot(u.T.dot(b))
+# residuals = NotImplementedError # ch_sum((a.dot(x) - b)**2, axis=0)
+# rank = NotImplementedError
+# s = NotImplementedError
+# return x, residuals, rank, s
+
+def lstsq(a, b, rcond=-1):
+ if rcond != -1:
+ raise Exception('non-default rcond not yet implemented')
+
+ x = Ch(lambda a, b : pinv(a).dot(b))
+ x.a = a
+ x.b = b
+ residuals = ch_sum( (x.a.dot(x) - x.b) **2 , axis=0)
+ rank = NotImplementedError
+ s = NotImplementedError
+
+ return x, residuals, rank, s
+
+def Svd(x, full_matrices=0, compute_uv=1):
+
+ if full_matrices != 0:
+ raise Exception('full_matrices must be 0')
+ if compute_uv != 1:
+ raise Exception('compute_uv must be 1')
+
+ need_transpose = x.shape[0] < x.shape[1]
+
+ if need_transpose:
+ x = x.T
+
+ svd_d = SvdD(x=x)
+ svd_v = SvdV(x=x, svd_d=svd_d)
+ svd_u = SvdU(x=x, svd_d=svd_d, svd_v=svd_v)
+
+ if need_transpose:
+ return svd_v, svd_d, svd_u.T
+ else:
+ return svd_u, svd_d, svd_v.T
+
+
+class Pinv(Ch):
+ dterms = 'mtx'
+
+ def on_changed(self, which):
+ mtx = self.mtx
+ if mtx.shape[1] > mtx.shape[0]:
+ result = mtx.T.dot(Inv(mtx.dot(mtx.T)))
+ else:
+ result = Inv(mtx.T.dot(mtx)).dot(mtx.T)
+ self._result = result
+
+ def compute_r(self):
+ return self._result.r
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.mtx:
+ return self._result.dr_wrt(self.mtx)
+
+# Couldn't make the SVD version of pinv work yet...
+#
+# class Pinv(Ch):
+# dterms = 'mtx'
+#
+# def on_changed(self, which):
+# u, s, v = Svd(self.mtx)
+# result = (v.T * (NanDivide(1.,row(s)))).dot(u.T)
+# self.add_dterm('_result', result)
+#
+# def compute_r(self):
+# return self._result.r
+#
+# def compute_dr_wrt(self, wrt):
+# if wrt is self._result:
+# return 1
+
+
+
+class LogAbsDet(Ch):
+ dterms = 'x'
+
+ def on_changed(self, which):
+ self.sign, self.slogdet = np.linalg.slogdet(self.x.r)
+
+ def compute_r(self):
+ return self.slogdet
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.x:
+ return row(np.linalg.inv(self.x.r).T)
+
+class SignLogAbsDet(Ch):
+ dterms = 'logabsdet',
+
+ def compute_r(self):
+ _ = self.logabsdet.r
+ return self.logabsdet.sign
+
+ def compute_dr_wrt(self, wrt):
+ return None
+
+
+class Det(Ch):
+ dterms = 'x'
+
+ def compute_r(self):
+ return np.linalg.det(self.x.r)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.x:
+ return row(self.r * np.linalg.inv(self.x.r).T)
+
+
+class Inv(Ch):
+ dterms = 'a'
+
+ def compute_r(self):
+ return np.linalg.inv(self.a.r)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.a:
+ return None
+
+ Ainv = self.r
+
+ if Ainv.ndim <= 2:
+ return -np.kron(Ainv, Ainv.T)
+ else:
+ Ainv = np.reshape(Ainv, (-1, Ainv.shape[-2], Ainv.shape[-1]))
+ AinvT = np.rollaxis(Ainv, -1, -2)
+ AinvT = np.reshape(AinvT, (-1, AinvT.shape[-2], AinvT.shape[-1]))
+ result = np.dstack([-np.kron(Ainv[i], AinvT[i]).T for i in range(Ainv.shape[0])]).T
+ result = sp.block_diag(result)
+
+ return result
+
+
+class SvdD(Ch):
+ dterms = 'x'
+
+ @depends_on('x')
+ def UDV(self):
+ result = np.linalg.svd(self.x.r, full_matrices=False)
+ result = [result[0], result[1], result[2].T]
+ result[1][np.abs(result[1]) < np.spacing(1)] = 0.
+ return result
+
+ def compute_r(self):
+ return self.UDV[1]
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.x:
+ return
+
+ u, d, v = self.UDV
+ shp = self.x.r.shape
+ u = u[:shp[0], :shp[1]]
+ v = v[:shp[1], :d.size]
+
+ result = np.einsum('ik,jk->kij', u, v)
+ result = result.reshape((result.shape[0], -1))
+ return result
+
+
+class SvdV(Ch):
+ terms = 'svd_d'
+ dterms = 'x'
+
+ def compute_r(self):
+ return self.svd_d.UDV[2]
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.x:
+ return
+
+ U,_D,V = self.svd_d.UDV
+
+ shp = self.svd_d.x.r.shape
+ mxsz = max(shp[0], shp[1])
+ #mnsz = min(shp[0], shp[1])
+ D = np.zeros(mxsz)
+ D[:_D.size] = _D
+
+ omega = np.zeros((shp[0], shp[1], shp[1], shp[1]))
+
+ M = shp[0]
+ N = shp[1]
+
+ assert(M >= N)
+
+ for i in range(shp[0]):
+ for j in range(shp[1]):
+ for k in range(N):
+ for l in range(k+1, N):
+ mtx = np.array([
+ [D[l],D[k]],
+ [D[k],D[l]]])
+
+ rhs = np.array([U[i,k]*V[j,l], -U[i,l]*V[j,k]])
+ result = np.linalg.solve(mtx, rhs)
+
+ omega[i,j,k,l] = result[1]
+ omega[i,j,l,k] = -result[1]
+
+ #print 'v size is %s' % (str(V.shape),)
+ #print 'v omega size is %s' % (str(omega.shape),)
+ assert(V.shape[1] == omega.shape[2])
+ return np.einsum('ak,ijkl->alij', -V, omega).reshape((self.r.size, wrt.r.size))
+
+
+class SvdU(Ch):
+ dterms = 'x'
+ terms = 'svd_d', 'svd_v'
+
+ def compute_r(self):
+ return self.svd_d.UDV[0]
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.x:
+ # return (
+ # self.svd_d.x.dot(self.svd_v)
+ # /
+ # self.svd_d.reshape((1,-1))
+ # ).dr_wrt(self.svd_d.x)
+ return (
+ NanDivide(
+ self.svd_d.x.dot(self.svd_v),
+ self.svd_d.reshape((1,-1)))
+ ).dr_wrt(self.svd_d.x)
+
+
+inv = Inv
+svd = Svd
+det = Det
+pinv = Pinv
+
+def slogdet(*args):
+ n = len(args)
+ if n == 1:
+ r2 = LogAbsDet(x=args[0])
+ r1 = SignLogAbsDet(r2)
+ return r1, r2
+ else:
+ r2 = [LogAbsDet(x=arg) for arg in args]
+ r1 = [SignLogAbsDet(r) for r in r2]
+ r2 = ch_concatenate(r2)
+ return r1, r2
+
+def main():
+
+ tmp = ch_random_randn(100).reshape((10,10))
+ print('chumpy version: ' + str(slogdet(tmp)[1].r))
+ print('old version:' + str(np.linalg.slogdet(tmp.r)[1]))
+
+ eps = 1e-10
+ diff = np.random.rand(100) * eps
+ diff_reshaped = diff.reshape((10,10))
+ print(np.linalg.slogdet(tmp.r+diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1])
+ print(slogdet(tmp)[1].dr_wrt(tmp).dot(diff))
+
+ print(np.linalg.slogdet(tmp.r)[0])
+ print(slogdet(tmp)[0])
+
+if __name__ == '__main__':
+ main()
+
diff --git a/chumpy-0.70/chumpy/logic.py b/chumpy-0.70/chumpy/logic.py
new file mode 100644
index 00000000..604b9a80
--- /dev/null
+++ b/chumpy-0.70/chumpy/logic.py
@@ -0,0 +1,39 @@
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+__author__ = 'matt'
+
+
+__all__ = [] # added to incrementally below
+
+from . import ch
+from .ch import Ch
+import numpy as np
+
+class LogicFunc(Ch):
+ dterms = 'a' # we keep this here so that changes to children of "a" will trigger cache changes
+ terms = 'args', 'kwargs', 'funcname'
+
+ def compute_r(self):
+ arr = self.a
+ fn = getattr(np, self.funcname)
+ return fn(arr, *self.args, **self.kwargs)
+
+ def compute_dr_wrt(self, wrt):
+ pass
+
+
+unaries = 'all', 'any', 'isfinite', 'isinf', 'isnan', 'isneginf', 'isposinf', 'logical_not'
+for unary in unaries:
+ exec("def %s(a, *args, **kwargs): return LogicFunc(a=a, args=args, kwargs=kwargs, funcname='%s')" % (unary, unary))
+__all__ += unaries
+
+
+
+if __name__ == '__main__':
+ from . import ch
+ print(all(np.array([1,2,3])))
+ print(isinf(np.array([0,2,3])))
diff --git a/chumpy-0.70/chumpy/monitor.py b/chumpy-0.70/chumpy/monitor.py
new file mode 100644
index 00000000..25f278b1
--- /dev/null
+++ b/chumpy-0.70/chumpy/monitor.py
@@ -0,0 +1,149 @@
+'''
+Logging service for tracking dr tree changes from root objective
+and record every step that incrementally changes the dr tree
+
+'''
+import os, sys, time
+import json
+import psutil
+
+import scipy.sparse as sp
+import numpy as np
+from . import reordering
+
+_TWO_20 = float(2 **20)
+
+'''
+memory utils
+
+'''
+def pdb_mem():
+ from .monitor import get_current_memory
+ mem = get_current_memory()
+ if mem > 7000:
+ import pdb;pdb.set_trace()
+
+def get_peak_mem():
+ '''
+ this returns peak memory use since process starts till the moment its called
+ '''
+ import resource
+ rusage_denom = 1024.
+ if sys.platform == 'darwin':
+ # ... it seems that in OSX the output is different units ...
+ rusage_denom = rusage_denom * rusage_denom
+ mem = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / rusage_denom
+ return mem
+
+def get_current_memory():
+ p = psutil.Process(os.getpid())
+ mem = p.memory_info()[0]/_TWO_20
+
+ return mem
+
+'''
+Helper for Profiler
+'''
+
+def build_cache_info(k, v, info_dict):
+ if v is not None:
+ issparse = sp.issparse(v)
+ size = v.size
+ if issparse:
+ nonzero = len(v.data)
+ else:
+ nonzero = np.count_nonzero(v)
+ info_dict[k.short_name] = {
+ 'sparse': issparse,
+ 'size' : str(size),
+ 'nonzero' : nonzero,
+ }
+
+
+def cache_info(ch_node):
+ result = {}
+ if isinstance(ch_node, reordering.Concatenate) and hasattr(ch_node, 'dr_cached') and len(ch_node.dr_cached) > 0:
+ for k, v in ch_node.dr_cached.items():
+ build_cache_info(k, v, result)
+ elif len(ch_node._cache['drs']) > 0:
+ for k, v in ch_node._cache['drs'].items():
+ build_cache_info(k, v, result)
+
+ return result
+
+class DrWrtProfiler(object):
+ base_path = os.path.abspath('profiles')
+
+ def __init__(self, root, base_path=None):
+ self.root = root.obj
+ self.history = []
+
+ ts = time.time()
+ if base_path:
+ self.base_path = base_path
+
+ self.path = os.path.join(self.base_path, 'profile_%s.json' % str(ts))
+ self.root_path = os.path.join(self.base_path, 'root_%s.json' % str(ts))
+
+
+ with open(self.root_path, 'w') as f:
+ json.dump(self.dump_tree(self.root), f, indent=4)
+
+ def dump_tree(self, node):
+ if not hasattr(node, 'dterms'):
+ return []
+
+ node_dict = self.serialize_node(node, verbose=False)
+ if hasattr(node, 'visited') and node.visited:
+ node_dict.update({'indirect':True})
+ return node_dict
+
+ node.visited = True
+ children_list = []
+ for dterm in node.dterms:
+ if hasattr(node, dterm):
+ child = getattr(node, dterm)
+ if hasattr(child, 'dterms') or hasattr(child, 'terms'):
+ children_list.append(self.dump_tree(child))
+ node_dict.update({'children':children_list})
+ return node_dict
+
+ def serialize_node(self, ch_node, verbose=True):
+ node_id = id(ch_node)
+ name = ch_node.short_name
+ ts = time.time()
+ status = ch_node._status
+ mem = get_current_memory()
+ node_cache_info = cache_info(ch_node)
+
+ rec = {
+ 'id': str(node_id),
+ 'indirect' : False,
+ }
+ if verbose:
+ rec.update({
+ 'name':name,
+ 'ts' : ts,
+ 'status':status,
+ 'mem': mem,
+ 'cache': node_cache_info,
+ })
+ return rec
+
+ def show_tree(self, label):
+ '''
+ show tree from the root node
+ '''
+ self.root.show_tree_cache(label)
+
+ def record(self, ch_node):
+ '''
+ Incremental changes
+ '''
+ rec = self.serialize_node(ch_node)
+ self.history.append(rec)
+
+ def harvest(self):
+ print('collecting and dump to file %s' % self.path)
+ with open(self.path, 'w') as f:
+ json.dump(self.history, f, indent=4)
\ No newline at end of file
diff --git a/chumpy-0.70/chumpy/np_tensordot.py b/chumpy-0.70/chumpy/np_tensordot.py
new file mode 100644
index 00000000..1b966bb7
--- /dev/null
+++ b/chumpy-0.70/chumpy/np_tensordot.py
@@ -0,0 +1,228 @@
+# Up to numpy 1.13, the numpy implementation of tensordot could be
+# reinterpreted using chumpy. With numpy 1.14 the implementation started using
+# ufunc.multiply.reduce which can't be understood by chumpy. This is the
+# chumpy-compatible implementation of tensodrot from numpy 1.13.3.
+#
+# i.e.
+#
+# import inspect
+# with open('np_tensordot.py', 'w') as f:
+# f.write(''.join(inspect.getsourcelines(np.tensordot)[0]))
+
+"""
+Copyright (c) 2005-2017, NumPy Developers.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ disclaimer in the documentation and/or other materials provided
+ with the distribution.
+
+ * Neither the name of the NumPy Developers nor the names of any
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+def tensordot(a, b, axes=2):
+ """
+ Compute tensor dot product along specified axes for arrays >= 1-D.
+
+ Given two tensors (arrays of dimension greater than or equal to one),
+ `a` and `b`, and an array_like object containing two array_like
+ objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
+ elements (components) over the axes specified by ``a_axes`` and
+ ``b_axes``. The third argument can be a single non-negative
+ integer_like scalar, ``N``; if it is such, then the last ``N``
+ dimensions of `a` and the first ``N`` dimensions of `b` are summed
+ over.
+
+ Parameters
+ ----------
+ a, b : array_like, len(shape) >= 1
+ Tensors to "dot".
+
+ axes : int or (2,) array_like
+ * integer_like
+ If an int N, sum over the last N axes of `a` and the first N axes
+ of `b` in order. The sizes of the corresponding axes must match.
+ * (2,) array_like
+ Or, a list of axes to be summed over, first sequence applying to `a`,
+ second to `b`. Both elements array_like must be of the same length.
+
+ See Also
+ --------
+ dot, einsum
+
+ Notes
+ -----
+ Three common use cases are:
+ * ``axes = 0`` : tensor product :math:`a\\otimes b`
+ * ``axes = 1`` : tensor dot product :math:`a\\cdot b`
+ * ``axes = 2`` : (default) tensor double contraction :math:`a:b`
+
+ When `axes` is integer_like, the sequence for evaluation will be: first
+ the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
+ Nth axis in `b` last.
+
+ When there is more than one axis to sum over - and they are not the last
+ (first) axes of `a` (`b`) - the argument `axes` should consist of
+ two sequences of the same length, with the first axis to sum over given
+ first in both sequences, the second axis second, and so forth.
+
+ Examples
+ --------
+ A "traditional" example:
+
+ >>> a = np.arange(60.).reshape(3,4,5)
+ >>> b = np.arange(24.).reshape(4,3,2)
+ >>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
+ >>> c.shape
+ (5, 2)
+ >>> c
+ array([[ 4400., 4730.],
+ [ 4532., 4874.],
+ [ 4664., 5018.],
+ [ 4796., 5162.],
+ [ 4928., 5306.]])
+ >>> # A slower but equivalent way of computing the same...
+ >>> d = np.zeros((5,2))
+ >>> for i in range(5):
+ ... for j in range(2):
+ ... for k in range(3):
+ ... for n in range(4):
+ ... d[i,j] += a[k,n,i] * b[n,k,j]
+ >>> c == d
+ array([[ True, True],
+ [ True, True],
+ [ True, True],
+ [ True, True],
+ [ True, True]], dtype=bool)
+
+ An extended example taking advantage of the overloading of + and \\*:
+
+ >>> a = np.array(range(1, 9))
+ >>> a.shape = (2, 2, 2)
+ >>> A = np.array(('a', 'b', 'c', 'd'), dtype=object)
+ >>> A.shape = (2, 2)
+ >>> a; A
+ array([[[1, 2],
+ [3, 4]],
+ [[5, 6],
+ [7, 8]]])
+ array([[a, b],
+ [c, d]], dtype=object)
+
+ >>> np.tensordot(a, A) # third argument default is 2 for double-contraction
+ array([abbcccdddd, aaaaabbbbbbcccccccdddddddd], dtype=object)
+
+ >>> np.tensordot(a, A, 1)
+ array([[[acc, bdd],
+ [aaacccc, bbbdddd]],
+ [[aaaaacccccc, bbbbbdddddd],
+ [aaaaaaacccccccc, bbbbbbbdddddddd]]], dtype=object)
+
+ >>> np.tensordot(a, A, 0) # tensor product (result too long to incl.)
+ array([[[[[a, b],
+ [c, d]],
+ ...
+
+ >>> np.tensordot(a, A, (0, 1))
+ array([[[abbbbb, cddddd],
+ [aabbbbbb, ccdddddd]],
+ [[aaabbbbbbb, cccddddddd],
+ [aaaabbbbbbbb, ccccdddddddd]]], dtype=object)
+
+ >>> np.tensordot(a, A, (2, 1))
+ array([[[abb, cdd],
+ [aaabbbb, cccdddd]],
+ [[aaaaabbbbbb, cccccdddddd],
+ [aaaaaaabbbbbbbb, cccccccdddddddd]]], dtype=object)
+
+ >>> np.tensordot(a, A, ((0, 1), (0, 1)))
+ array([abbbcccccddddddd, aabbbbccccccdddddddd], dtype=object)
+
+ >>> np.tensordot(a, A, ((2, 1), (1, 0)))
+ array([acccbbdddd, aaaaacccccccbbbbbbdddddddd], dtype=object)
+
+ """
+ try:
+ iter(axes)
+ except:
+ axes_a = list(range(-axes, 0))
+ axes_b = list(range(0, axes))
+ else:
+ axes_a, axes_b = axes
+ try:
+ na = len(axes_a)
+ axes_a = list(axes_a)
+ except TypeError:
+ axes_a = [axes_a]
+ na = 1
+ try:
+ nb = len(axes_b)
+ axes_b = list(axes_b)
+ except TypeError:
+ axes_b = [axes_b]
+ nb = 1
+
+ a, b = asarray(a), asarray(b)
+ as_ = a.shape
+ nda = a.ndim
+ bs = b.shape
+ ndb = b.ndim
+ equal = True
+ if na != nb:
+ equal = False
+ else:
+ for k in range(na):
+ if as_[axes_a[k]] != bs[axes_b[k]]:
+ equal = False
+ break
+ if axes_a[k] < 0:
+ axes_a[k] += nda
+ if axes_b[k] < 0:
+ axes_b[k] += ndb
+ if not equal:
+ raise ValueError("shape-mismatch for sum")
+
+ # Move the axes to sum over to the end of "a"
+ # and to the front of "b"
+ notin = [k for k in range(nda) if k not in axes_a]
+ newaxes_a = notin + axes_a
+ N2 = 1
+ for axis in axes_a:
+ N2 *= as_[axis]
+ newshape_a = (-1, N2)
+ olda = [as_[axis] for axis in notin]
+
+ notin = [k for k in range(ndb) if k not in axes_b]
+ newaxes_b = axes_b + notin
+ N2 = 1
+ for axis in axes_b:
+ N2 *= bs[axis]
+ newshape_b = (N2, -1)
+ oldb = [bs[axis] for axis in notin]
+
+ at = a.transpose(newaxes_a).reshape(newshape_a)
+ bt = b.transpose(newaxes_b).reshape(newshape_b)
+ res = dot(at, bt)
+ return res.reshape(olda + oldb)
diff --git a/chumpy-0.70/chumpy/optimization.py b/chumpy-0.70/chumpy/optimization.py
new file mode 100755
index 00000000..4f68b65e
--- /dev/null
+++ b/chumpy-0.70/chumpy/optimization.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python
+
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+__all__ = ['minimize']
+
+import numpy as np
+from . import ch
+import scipy.sparse as sp
+import scipy.optimize
+
+from .optimization_internal import minimize_dogleg
+
+#from memory_profiler import profile, memory_usage
+
+# def disable_cache_for_single_parent_node(node):
+# if hasattr(node, '_parents') and len(node._parents.keys()) == 1:
+# node.want_cache = False
+
+
+# Nelder-Mead
+# Powell
+# CG
+# BFGS
+# Newton-CG
+# Anneal
+# L-BFGS-B
+# TNC
+# COBYLA
+# SLSQP
+# dogleg
+# trust-ncg
+def minimize(fun, x0, method='dogleg', bounds=None, constraints=(), tol=None, callback=None, options=None):
+
+ if method == 'dogleg':
+ if options is None: options = {}
+ return minimize_dogleg(fun, free_variables=x0, on_step=callback, **options)
+
+ if isinstance(fun, list) or isinstance(fun, tuple):
+ fun = ch.concatenate([f.ravel() for f in fun])
+ if isinstance(fun, dict):
+ fun = ch.concatenate([f.ravel() for f in list(fun.values())])
+ obj = fun
+ free_variables = x0
+
+
+ from .ch import SumOfSquares
+
+ hessp = None
+ hess = None
+ if obj.size == 1:
+ obj_scalar = obj
+ else:
+ obj_scalar = SumOfSquares(obj)
+
+ def hessp(vs, p,obj, obj_scalar, free_variables):
+ changevars(vs,obj,obj_scalar,free_variables)
+ if not hasattr(hessp, 'vs'):
+ hessp.vs = vs*0+1e16
+ if np.max(np.abs(vs-hessp.vs)) > 0:
+
+ J = ns_jacfunc(vs,obj,obj_scalar,free_variables)
+ hessp.J = J
+ hessp.H = 2. * J.T.dot(J)
+ hessp.vs = vs
+ return np.array(hessp.H.dot(p)).ravel()
+ #return 2*np.array(hessp.J.T.dot(hessp.J.dot(p))).ravel()
+
+ if method.lower() != 'newton-cg':
+ def hess(vs, obj, obj_scalar, free_variables):
+ changevars(vs,obj,obj_scalar,free_variables)
+ if not hasattr(hessp, 'vs'):
+ hessp.vs = vs*0+1e16
+ if np.max(np.abs(vs-hessp.vs)) > 0:
+ J = ns_jacfunc(vs,obj,obj_scalar,free_variables)
+ hessp.H = 2. * J.T.dot(J)
+ return hessp.H
+
+ def changevars(vs, obj, obj_scalar, free_variables):
+ cur = 0
+ changed = False
+ for idx, freevar in enumerate(free_variables):
+ sz = freevar.r.size
+ newvals = vs[cur:cur+sz].copy().reshape(free_variables[idx].shape)
+ if np.max(np.abs(newvals-free_variables[idx]).ravel()) > 0:
+ free_variables[idx][:] = newvals
+ changed = True
+
+ cur += sz
+
+ methods_without_callback = ('anneal', 'powell', 'cobyla', 'slsqp')
+ if callback is not None and changed and method.lower() in methods_without_callback:
+ callback(None)
+
+ return changed
+
+ def residuals(vs,obj, obj_scalar, free_variables):
+ changevars(vs, obj, obj_scalar, free_variables)
+ residuals = obj_scalar.r.ravel()[0]
+ return residuals
+
+ def scalar_jacfunc(vs,obj, obj_scalar, free_variables):
+ if not hasattr(scalar_jacfunc, 'vs'):
+ scalar_jacfunc.vs = vs*0+1e16
+ if np.max(np.abs(vs-scalar_jacfunc.vs)) == 0:
+ return scalar_jacfunc.J
+
+ changevars(vs, obj, obj_scalar, free_variables)
+
+ if True: # faster, at least on some problems
+ result = np.concatenate([np.array(obj_scalar.lop(wrt, np.array([[1]]))).ravel() for wrt in free_variables])
+ else:
+ jacs = [obj_scalar.dr_wrt(wrt) for wrt in free_variables]
+ for idx, jac in enumerate(jacs):
+ if sp.issparse(jac):
+ jacs[idx] = jacs[idx].todense()
+ result = np.concatenate([jac.ravel() for jac in jacs])
+
+ scalar_jacfunc.J = result
+ scalar_jacfunc.vs = vs
+ return result.ravel()
+
+ def ns_jacfunc(vs,obj, obj_scalar, free_variables):
+ if not hasattr(ns_jacfunc, 'vs'):
+ ns_jacfunc.vs = vs*0+1e16
+ if np.max(np.abs(vs-ns_jacfunc.vs)) == 0:
+ return ns_jacfunc.J
+
+ changevars(vs, obj, obj_scalar, free_variables)
+ jacs = [obj.dr_wrt(wrt) for wrt in free_variables]
+ result = hstack(jacs)
+
+ ns_jacfunc.J = result
+ ns_jacfunc.vs = vs
+ return result
+
+
+ x1 = scipy.optimize.minimize(
+ method=method,
+ fun=residuals,
+ callback=callback,
+ x0=np.concatenate([free_variable.r.ravel() for free_variable in free_variables]),
+ jac=scalar_jacfunc,
+ hessp=hessp, hess=hess, args=(obj, obj_scalar, free_variables),
+ bounds=bounds, constraints=constraints, tol=tol, options=options).x
+
+ changevars(x1, obj, obj_scalar, free_variables)
+ return free_variables
+
+
+def main():
+ pass
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/chumpy-0.70/chumpy/optimization_internal.py b/chumpy-0.70/chumpy/optimization_internal.py
new file mode 100644
index 00000000..10436338
--- /dev/null
+++ b/chumpy-0.70/chumpy/optimization_internal.py
@@ -0,0 +1,455 @@
+import sys
+import warnings
+import numpy as np
+import scipy.sparse as sp
+from . import ch, utils
+from .ch import pif
+from .utils import timer
+
+
+def clear_cache_single(node):
+ node._cache['drs'].clear()
+ if hasattr(node, 'dr_cached'):
+ node.dr_cached.clear()
+
+def vstack(x):
+ x = [a if not isinstance(a, sp.linalg.interface.LinearOperator) else a.dot(np.eye(a.shape[1])) for a in x]
+ return sp.vstack(x, format='csc') if any([sp.issparse(a) for a in x]) else np.vstack(x)
+def hstack(x):
+ x = [a if not isinstance(a, sp.linalg.interface.LinearOperator) else a.dot(np.eye(a.shape[1])) for a in x]
+ return sp.hstack(x, format='csc') if any([sp.issparse(a) for a in x]) else np.hstack(x)
+
+
+_giter = 0
+class ChInputsStacked(ch.Ch):
+ dterms = 'x', 'obj'
+ terms = 'free_variables'
+
+ def compute_r(self):
+ if not hasattr(self, 'fevals'):
+ self.fevals = 0
+ self.fevals += 1
+ return self.obj.r.ravel()
+
+ def dr_wrt(self, wrt, profiler=None):
+ '''
+ Loop over free variables and delete cache for the whole tree after finished each one
+ '''
+ if wrt is self.x:
+ jacs = []
+ for fvi, freevar in enumerate(self.free_variables):
+ tm = timer()
+ if isinstance(freevar, ch.Select):
+ new_jac = self.obj.dr_wrt(freevar.a, profiler=profiler)
+ try:
+ new_jac = new_jac[:, freevar.idxs]
+ except:
+ # non-csc sparse matrices may not support column-wise indexing
+ new_jac = new_jac.tocsc()[:, freevar.idxs]
+ else:
+ new_jac = self.obj.dr_wrt(freevar, profiler=profiler)
+
+ pif('dx wrt {} in {}sec, sparse: {}'.format(freevar.short_name, tm(), sp.issparse(new_jac)))
+
+ if self._make_dense and sp.issparse(new_jac):
+ new_jac = new_jac.todense()
+ if self._make_sparse and not sp.issparse(new_jac):
+ new_jac = sp.csc_matrix(new_jac)
+
+ if new_jac is None:
+ raise Exception(
+ 'Objective has no derivative wrt free variable {}. '
+ 'You should likely remove it.'.format(fvi))
+
+ jacs.append(new_jac)
+ tm = timer()
+ utils.dfs_do_func_on_graph(self.obj, clear_cache_single)
+ pif('dfs_do_func_on_graph in {}sec'.format(tm()))
+ tm = timer()
+ J = hstack(jacs)
+ pif('hstack in {}sec'.format(tm()))
+ return J
+
+ def on_changed(self, which):
+ global _giter
+ _giter += 1
+ if 'x' in which:
+ pos = 0
+ for idx, freevar in enumerate(self.free_variables):
+ sz = freevar.r.size
+ rng = np.arange(pos, pos+sz)
+ if isinstance(self.free_variables[idx], ch.Select):
+ # Deal with nested selects
+ selects = []
+ a = self.free_variables[idx]
+ while isinstance(a, ch.Select):
+ selects.append(a.idxs)
+ a = a.a
+ newv = a.x.copy()
+ idxs = selects.pop()
+ while len(selects) > 0:
+ idxs = idxs[selects.pop()]
+ newv.ravel()[idxs] = self.x.r.ravel()[rng]
+ a.__setattr__('x', newv, _giter)
+ elif isinstance(self.free_variables[idx].x, np.ndarray):
+ self.free_variables[idx].__setattr__('x', self.x.r[rng].copy().reshape(self.free_variables[idx].x.shape), _giter)
+ else: # a number
+ self.free_variables[idx].__setattr__('x', self.x.r[rng], _giter)
+ pos += sz
+
+ @property
+ def J(self):
+ '''
+ Compute Jacobian. Analyze dr graph first to disable unnecessary caching
+ '''
+ result = self.dr_wrt(self.x, profiler=self.profiler).copy()
+ if self.profiler:
+ self.profiler.harvest()
+ return np.atleast_2d(result) if not sp.issparse(result) else result
+
+
+def setup_sparse_solver(sparse_solver):
+ _solver_fns = {
+ 'cg': lambda A, x, M=None : sp.linalg.cg(A, x, M=M, tol=1e-10)[0],
+ 'spsolve': lambda A, x : sp.linalg.spsolve(A, x)
+ }
+ if callable(sparse_solver):
+ return sparse_solver
+ elif isinstance(sparse_solver, str) and sparse_solver in list(_solver_fns.keys()):
+ return _solver_fns[sparse_solver]
+ else:
+ raise Exception('sparse_solver argument must be either a string in the set (%s) or have the api of scipy.sparse.linalg.spsolve.' % ', '.join(list(_solver_fns.keys())))
+
+
+def setup_objective(obj, free_variables, on_step=None, disp=True, make_dense=False):
+ '''
+ obj here can be a list of ch objects or a dict of label: ch objects. Either way, the ch
+ objects will be merged into one objective using a ChInputsStacked. The labels are just used
+ for printing out values per objective with each iteration. If make_dense is True, the
+ resulting object with return a desne Jacobian
+ '''
+ # Validate free variables
+ num_unique_ids = len(np.unique(np.array([id(freevar) for freevar in free_variables])))
+ if num_unique_ids != len(free_variables):
+ raise Exception('The "free_variables" param contains duplicate variables.')
+ # Extract labels
+ labels = {}
+ if isinstance(obj, list) or isinstance(obj, tuple):
+ obj = ch.concatenate([f.ravel() for f in obj])
+ elif isinstance(obj, dict):
+ labels = obj
+ obj = ch.concatenate([f.ravel() for f in list(obj.values())])
+ # build objective
+ x = np.concatenate([freevar.r.ravel() for freevar in free_variables])
+ obj = ChInputsStacked(obj=obj, free_variables=free_variables, x=x, make_dense=make_dense)
+ # build callback
+ def callback():
+ if on_step is not None:
+ on_step(obj)
+ if disp:
+ report_line = ['%.2e' % (np.sum(obj.r**2),)]
+ for label, objective in sorted(list(labels.items()), key=lambda x: x[0]):
+ report_line.append('%s: %.2e' % (label, np.sum(objective.r**2)))
+ report_line = " | ".join(report_line) + '\n'
+ sys.stderr.write(report_line)
+ return obj, callback
+
+
+class DoglegState(object):
+ '''
+ Dogleg preserves a great deal of state from iteration to iteration. Many of the things
+ that we need to calculate are dependent only on this state (e.g. the various trust region
+ steps, the current jacobian and the A & g that depends on it, etc.). Holding the state and
+ the various methods based on that state here allows us to seperate a lot of the jacobian
+ based calculation from the flow control of the optmization.
+
+ There will be once instance of DoglegState per invocation of minimize_dogleg.
+ '''
+ def __init__(self, delta, solve):
+ self.iteration = 0
+ self._d_gn = None # gauss-newton
+ self._d_sd = None # steepest descent
+ self._d_dl = None # dogleg
+ self.J = None
+ self.A = None
+ self.g = None
+ self._p = None
+ self.delta = delta
+ self.solve = solve
+ self._r = None
+ self.rho = None
+ self.done = False
+
+ @property
+ def p(self):
+ '''p is the current proposed input vector'''
+ return self._p
+ @p.setter
+ def p(self, val):
+ self._p = val.reshape((-1, 1))
+
+ # induce some certainty about what the shape of the steps are
+ @property
+ def d_gn(self):
+ return self._d_gn
+ @d_gn.setter
+ def d_gn(self, val):
+ if val is not None:
+ val = val.reshape((-1, 1))
+ self._d_gn = val
+
+ @property
+ def d_sd(self):
+ return self._d_sd
+ @d_sd.setter
+ def d_sd(self, val):
+ if val is not None:
+ val = val.reshape((-1, 1))
+ self._d_sd = val
+
+ @property
+ def d_dl(self):
+ return self._d_dl
+ @d_dl.setter
+ def d_dl(self, val):
+ if val is not None:
+ val = val.reshape((-1, 1))
+ self._d_dl = val
+
+ @property
+ def step(self):
+ return self.d_dl.reshape((-1, 1))
+ @property
+ def step_size(self):
+ return np.linalg.norm(self.d_dl)
+
+ def start_iteration(self):
+ self.iteration += 1
+ pif('beginning iteration %d' % (self.iteration,))
+ self.d_sd = (np.linalg.norm(self.g)**2 / np.linalg.norm(self.J.dot(self.g))**2 * self.g).ravel()
+ self.d_gn = None
+
+ @property
+ def r(self):
+ '''r is the residual at the current p'''
+ return self._r
+ @r.setter
+ def r(self, val):
+ self._r = val.copy().reshape((-1, 1))
+ self.updateAg()
+
+ def updateAg(self):
+ tm = timer()
+ pif('updating A and g...')
+ JT = self.J.T
+ self.A = JT.dot(self.J)
+ self.g = JT.dot(-self.r).reshape((-1, 1))
+ pif('A and g updated in %.2fs' % tm())
+
+ def update_step(self):
+ # if the Cauchy point is outside the trust region,
+ # take that direction but only to the edge of the trust region
+ if self.delta is not None and np.linalg.norm(self.d_sd) >= self.delta:
+ pif('PROGRESS: Using stunted cauchy')
+ self.d_dl = np.array(self.delta/np.linalg.norm(self.d_sd) * self.d_sd).ravel()
+ else:
+ if self.d_gn is None:
+ # We only need to compute this once per iteration
+ self.updateGN()
+ # if the gauss-newton solution is within the trust region, use it
+ if self.delta is None or np.linalg.norm(self.d_gn) <= self.delta:
+ pif('PROGRESS: Using gauss-newton solution')
+ self.d_dl = np.array(self.d_gn).ravel()
+ if self.delta is None:
+ self.delta = np.linalg.norm(self.d_gn)
+ else: # between cauchy step and gauss-newton step
+ pif('PROGRESS: between cauchy and gauss-newton')
+ # apply step
+ self.d_dl = self.d_sd + self.beta_multiplier * (self.d_gn - self.d_sd)
+
+ @property
+ def beta_multiplier(self):
+ delta_sq = self.delta**2
+ diff = self.d_gn - self.d_sd
+ sqnorm_sd = np.linalg.norm(self.d_sd)**2
+ pnow = diff.T.dot(diff)*delta_sq + self.d_gn.T.dot(self.d_sd)**2 - np.linalg.norm(self.d_gn)**2 * sqnorm_sd
+ return float(delta_sq - sqnorm_sd) / float((diff).T.dot(self.d_sd) + np.sqrt(pnow))
+
+ def updateGN(self):
+ tm = timer()
+ if sp.issparse(self.A):
+ self.A.eliminate_zeros()
+ pif('sparse solve...sparsity infill is %.3f%% (hessian %dx%d)' % (100. * self.A.nnz / (self.A.shape[0] * self.A.shape[1]), self.A.shape[0], self.A.shape[1]))
+ if self.g.size > 1:
+ self.d_gn = self.solve(self.A, self.g).ravel()
+ if np.any(np.isnan(self.d_gn)) or np.any(np.isinf(self.d_gn)):
+ from scipy.sparse.linalg import lsqr
+ warnings.warn("sparse solve failed, falling back to lsqr")
+ self.d_gn = lsqr(self.A, self.g)[0].ravel()
+ else:
+ self.d_gn = np.atleast_1d(self.g.ravel()[0]/self.A[0,0])
+ pif('sparse solve...done in %.2fs' % tm())
+ else:
+ pif('dense solve...')
+ try:
+ self.d_gn = np.linalg.solve(self.A, self.g).ravel()
+ except Exception:
+ warnings.warn("dense solve failed, falling back to lsqr")
+ self.d_gn = np.linalg.lstsq(self.A, self.g)[0].ravel()
+ pif('dense solve...done in %.2fs' % tm())
+
+ def updateJ(self, obj):
+ tm = timer()
+ pif('computing Jacobian...')
+ self.J = obj.J
+ if self.J is None:
+ raise Exception("Computing Jacobian failed!")
+ if sp.issparse(self.J):
+ tm2 = timer()
+ self.J = self.J.tocsr()
+ pif('converted to csr in {}secs'.format(tm2()))
+ assert(self.J.nnz > 0)
+ elif ch.VERBOSE:
+ nonzero = np.count_nonzero(self.J)
+ pif('Jacobian dense with sparsity %.3f' % (nonzero/self.J.size))
+ pif('Jacobian (%dx%d) computed in %.2fs' % (self.J.shape[0], self.J.shape[1], tm()))
+ if self.J.shape[1] != self.p.size:
+ raise Exception('Jacobian size mismatch with objective input')
+ return self.J
+
+ class Trial(object):
+ '''
+ Inside each iteration of dogleg we propose a step and check to see if it's actually
+ an improvement before we accept it. This class encapsulates that trial and the
+ testing to see if it is actually an improvement.
+
+ There will be one instance of Trial per iteration in dogleg.
+ '''
+ def __init__(self, proposed_r, state):
+ self.r = proposed_r
+ self.state = state
+ # rho is the ratio of...
+ # (improvement in SSE) / (predicted improvement in SSE)
+ self.rho = np.linalg.norm(state.r)**2 - np.linalg.norm(proposed_r)**2
+ if self.rho > 0:
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore',category=RuntimeWarning)
+ predicted_improvement = 2. * state.g.T.dot(state.d_dl) - state.d_dl.T.dot(state.A.dot(state.d_dl))
+ self.rho /= predicted_improvement
+
+ @property
+ def is_improvement(self):
+ return self.rho > 0
+
+ @property
+ def improvement(self):
+ return (np.linalg.norm(self.state.r)**2 - np.linalg.norm(self.r)**2) / np.linalg.norm(self.state.r)**2
+
+ def trial_r(self, proposed_r):
+ return self.Trial(proposed_r, self)
+
+ def updateRadius(self, rho, lb=.05, ub=.9):
+ if rho > ub:
+ self.delta = max(self.delta, 2.5*np.linalg.norm(self.d_dl))
+ elif rho < lb:
+ self.delta *= .25
+
+
+def minimize_dogleg(obj, free_variables, on_step=None,
+ maxiter=200, max_fevals=np.inf, sparse_solver='spsolve',
+ disp=True, e_1=1e-15, e_2=1e-15, e_3=0., delta_0=None,
+ treat_as_dense=False):
+ """"Nonlinear optimization using Powell's dogleg method.
+ See Lourakis et al, 2005, ICCV '05, "Is Levenberg-Marquardt the
+ Most Efficient Optimization for Implementing Bundle Adjustment?":
+ http://www.ics.forth.gr/cvrl/publications/conferences/0201-P0401-lourakis-levenberg.pdf
+
+ e_N are stopping conditions:
+ e_1 is gradient magnatude threshold
+ e_2 is step size magnatude threshold
+ e_3 is improvement threshold (as a ratio; 0.1 means it must improve by 10%% at each step)
+
+ maxiter and max_fevals are also stopping conditions. Note that they're not quite the same,
+ as an iteration may evaluate the function more than once.
+
+ sparse_solver is the solver to use to calculate the Gauss-Newton step in the common case
+ that the Jacobian is sparse. It can be 'spsolve' (in which case scipy.sparse.linalg.spsolve
+ will be used), 'cg' (in which case scipy.sparse.linalg.cg will be used), or any callable
+ that matches the api of scipy.sparse.linalg.spsolve to solve `A x = b` for x where A is sparse.
+
+ cg, uses a Conjugate Gradient method, and will be faster if A is sparse but x is dense.
+ spsolve will be faster if x is also sparse.
+
+ delta_0 defines the initial trust region. Generally speaking, if this is set too low then
+ the optimization will never really go anywhere (to small a trust region to make any real
+ progress before running out of iterations) and if it's set too high then the optimization
+ will diverge immidiately and go wild (such a large trust region that the initial step so
+ far overshoots that it can't recover). If it's left as None, it will be automatically
+ estimated on the first iteration; it's always updated at each iteration, so this is treated
+ only as an initialization.
+
+ handle_as_dense explicitly converts all Jacobians of obj to dense matrices
+ """
+
+
+ solve = setup_sparse_solver(sparse_solver)
+ obj, callback = setup_objective(obj, free_variables, on_step=on_step, disp=disp,
+ make_dense=treat_as_dense)
+
+ state = DoglegState(delta=delta_0, solve=solve)
+ state.p = obj.x.r
+
+ #inject profiler if in DEBUG mode
+ if ch.DEBUG:
+ from .monitor import DrWrtProfiler
+ obj.profiler = DrWrtProfiler(obj)
+
+ callback()
+ state.updateJ(obj)
+ state.r = obj.r
+
+ def stop(msg):
+ if not state.done:
+ pif(msg)
+ state.done = True
+
+ if np.linalg.norm(state.g, np.inf) < e_1:
+ stop('stopping because norm(g, np.inf) < %.2e' % e_1)
+ while not state.done:
+ state.start_iteration()
+ while True:
+ state.update_step()
+ if state.step_size <= e_2 * np.linalg.norm(state.p):
+ stop('stopping because of small step size (norm_dl < %.2e)' % (e_2 * np.linalg.norm(state.p)))
+ else:
+ tm = timer()
+ obj.x = state.p + state.step
+ trial = state.trial_r(obj.r)
+ pif('Residuals computed in %.2fs' % tm())
+ # if the objective function improved, update input parameter estimate.
+ # Note that the obj.x already has the new parms,
+ # and we should not set them again to the same (or we'll bust the cache)
+ if trial.is_improvement:
+ state.p = state.p + state.step
+ callback()
+ if e_3 > 0. and trial.improvement < e_3:
+ stop('stopping because improvement < %.1e%%' % (100*e_3))
+ else:
+ state.updateJ(obj)
+ state.r = trial.r
+ if np.linalg.norm(state.g, np.inf) < e_1:
+ stop('stopping because norm(g, np.inf) < %.2e' % e_1)
+ else: # Put the old parms back
+ obj.x = ch.Ch(state.p)
+ obj.on_changed('x') # copies from flat vector to free variables
+ # update our trust region
+ state.updateRadius(trial.rho)
+ if state.delta <= e_2*np.linalg.norm(state.p):
+ stop('stopping because trust region is too small')
+ if state.done or trial.is_improvement or (obj.fevals >= max_fevals):
+ break
+ if state.iteration >= maxiter:
+ stop('stopping because max number of user-specified iterations (%d) has been met' % maxiter)
+ elif obj.fevals >= max_fevals:
+ stop('stopping because max number of user-specified func evals (%d) has been met' % max_fevals)
+ return obj.free_variables
diff --git a/chumpy-0.70/chumpy/reordering.py b/chumpy-0.70/chumpy/reordering.py
new file mode 100644
index 00000000..0ab1cd34
--- /dev/null
+++ b/chumpy-0.70/chumpy/reordering.py
@@ -0,0 +1,454 @@
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+from .ch import Ch
+import numpy as np
+from .utils import row, col
+import scipy.sparse as sp
+import weakref
+
+__all__ = ['sort', 'tile', 'repeat', 'transpose', 'rollaxis', 'swapaxes', 'reshape', 'Select',
+ 'atleast_1d', 'atleast_2d', 'atleast_3d', 'squeeze', 'expand_dims', 'fliplr', 'flipud',
+ 'concatenate', 'vstack', 'hstack', 'dstack', 'ravel', 'diag', 'diagflat', 'roll', 'rot90']
+
+# Classes deriving from "Permute" promise to only reorder/reshape
+class Permute(Ch):
+ pass
+
+def ravel(a, order='C'):
+ assert(order=='C')
+ if isinstance (a, np.ndarray):
+ self = Ch(a)
+
+ return reshape(a=a, newshape=(-1,))
+
+class Reorder(Permute):
+ dterms = 'a',
+
+ def on_changed(self, which):
+ if not hasattr(self, 'dr_lookup'):
+ self.dr_lookup = {}
+
+ def compute_r(self):
+ return self.reorder(self.a.r)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.a:
+ if False:
+ from scipy.sparse.linalg.interface import LinearOperator
+ return LinearOperator((self.size, wrt.size), lambda x : self.reorder(x.reshape(self.a.shape)).ravel())
+ else:
+ a = self.a
+ asz = a.size
+ ashape = a.shape
+ key = self.unique_reorder_id()
+ if key not in self.dr_lookup or key is None:
+ JS = self.reorder(np.arange(asz).reshape(ashape))
+ IS = np.arange(JS.size)
+ data = np.ones_like(IS)
+ shape = JS.shape
+ self.dr_lookup[key] = sp.csc_matrix((data, (IS, JS.ravel())), shape=(self.r.size, wrt.r.size))
+ return self.dr_lookup[key]
+
+class Sort(Reorder):
+ dterms = 'a'
+ terms = 'axis', 'kind', 'order'
+
+ def reorder(self, a): return np.sort(a, self.axis, self.kind, self.order)
+ def unique_reorder_id(self): return None
+
+def sort(a, axis=-1, kind='quicksort', order=None):
+ return Sort(a=a, axis=axis, kind=kind, order=order)
+
+
+class Tile(Reorder):
+ dterms = 'a',
+ terms = 'reps',
+ term_order = 'a', 'reps'
+
+ def reorder(self, a): return np.tile(a, self.reps)
+ def unique_reorder_id(self): return (self.a.shape, tuple(self.reps))
+
+def tile(A, reps):
+ return Tile(a=A, reps=reps)
+
+
+class Diag(Reorder):
+ dterms = 'a',
+ terms = 'k',
+
+ def reorder(self, a): return np.diag(a, self.k)
+ def unique_reorder_id(self): return (self.a.shape, self.k)
+
+def diag(v, k=0):
+ return Diag(a=v, k=k)
+
+class DiagFlat(Reorder):
+ dterms = 'a',
+ terms = 'k',
+
+ def reorder(self, a): return np.diagflat(a, self.k)
+ def unique_reorder_id(self): return (self.a.shape, self.k)
+
+def diagflat(v, k=0):
+ return DiagFlat(a=v, k=k)
+
+
+class Repeat(Reorder):
+ dterms = 'a',
+ terms = 'repeats', 'axis'
+
+ def reorder(self, a): return np.repeat(a, self.repeats, self.axis)
+ def unique_reorder_id(self): return (self.repeats, self.axis)
+
+def repeat(a, repeats, axis=None):
+ return Repeat(a=a, repeats=repeats, axis=axis)
+
+class transpose(Reorder):
+ dterms = 'a'
+ terms = 'axes'
+ term_order = 'a', 'axes'
+
+ def reorder(self, a): return np.require(np.transpose(a, axes=self.axes), requirements='C')
+ def unique_reorder_id(self): return (self.a.shape, None if self.axes is None else tuple(self.axes))
+ def on_changed(self, which):
+ if not hasattr(self, 'axes'):
+ self.axes = None
+ super(self.__class__, self).on_changed(which)
+
+class rollaxis(Reorder):
+ dterms = 'a'
+ terms = 'axis', 'start'
+ term_order = 'a', 'axis', 'start'
+
+ def reorder(self, a): return np.rollaxis(a, axis=self.axis, start=self.start)
+ def unique_reorder_id(self): return (self.a.shape, self.axis, self.start)
+ def on_changed(self, which):
+ if not hasattr(self, 'start'):
+ self.start = 0
+ super(self.__class__, self).on_changed(which)
+
+class swapaxes(Reorder):
+ dterms = 'a'
+ terms = 'axis1', 'axis2'
+ term_order = 'a', 'axis1', 'axis2'
+
+ def reorder(self, a): return np.swapaxes(a, axis1=self.axis1, axis2=self.axis2)
+ def unique_reorder_id(self): return (self.a.shape, self.axis1, self.axis2)
+
+
+
+class Roll(Reorder):
+ dterms = 'a',
+ terms = 'shift', 'axis'
+ term_order = 'a', 'shift', 'axis'
+
+ def reorder(self, a): return np.roll(a, self.shift, self.axis)
+ def unique_reorder_id(self): return (self.shift, self.axis)
+
+def roll(a, shift, axis=None):
+ return Roll(a, shift, axis)
+
+class Rot90(Reorder):
+ dterms = 'a',
+ terms = 'k',
+
+ def reorder(self, a): return np.rot90(a, self.k)
+ def unique_reorder_id(self): return (self.a.shape, self.k)
+
+def rot90(m, k=1):
+ return Rot90(a=m, k=k)
+
+class Reshape(Permute):
+ dterms = 'a',
+ terms = 'newshape',
+ term_order= 'a', 'newshape'
+
+ def compute_r(self):
+ return self.a.r.reshape(self.newshape)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.a:
+ return sp.eye(self.a.size, self.a.size)
+ #return self.a.dr_wrt(wrt)
+
+# def reshape(a, newshape):
+# if isinstance(a, Reshape) and a.newshape == newshape:
+# return a
+# return Reshape(a=a, newshape=newshape)
+def reshape(a, newshape):
+ while isinstance(a, Reshape):
+ a = a.a
+ return Reshape(a=a, newshape=newshape)
+
+# class And(Ch):
+# dterms = 'x1', 'x2'
+#
+# def compute_r(self):
+# if True:
+# needs_work = [self.x1, self.x2]
+# done = []
+# while len(needs_work) > 0:
+# todo = needs_work.pop()
+# if isinstance(todo, And):
+# needs_work += [todo.x1, todo.x2]
+# else:
+# done = [todo] + done
+# return np.concatenate([d.r.ravel() for d in done])
+# else:
+# return np.concatenate((self.x1.r.ravel(), self.x2.r.ravel()))
+#
+# # This is only here for reverse mode to work.
+# # Most of the time, the overridden dr_wrt is callpath gets used.
+# def compute_dr_wrt(self, wrt):
+#
+# if wrt is not self.x1 and wrt is not self.x2:
+# return
+#
+# input_len = wrt.r.size
+# x1_len = self.x1.r.size
+# x2_len = self.x2.r.size
+#
+# mtxs = []
+# if wrt is self.x1:
+# mtxs.append(sp.spdiags(np.ones(x1_len), 0, x1_len, x1_len))
+# else:
+# mtxs.append(sp.csc_matrix((x1_len, input_len)))
+#
+# if wrt is self.x2:
+# mtxs.append(sp.spdiags(np.ones(x2_len), 0, x2_len, x2_len))
+# else:
+# mtxs.append(sp.csc_matrix((x2_len, input_len)))
+#
+#
+# if any([sp.issparse(mtx) for mtx in mtxs]):
+# result = sp.vstack(mtxs, format='csc')
+# else:
+# result = np.vstack(mtxs)
+#
+# return result
+#
+# def dr_wrt(self, wrt, want_stacks=False, reverse_mode=False):
+# self._call_on_changed()
+#
+# input_len = wrt.r.size
+# x1_len = self.x1.r.size
+# x2_len = self.x2.r.size
+#
+# mtxs = []
+# if wrt is self.x1:
+# mtxs.append(sp.spdiags(np.ones(x1_len), 0, x1_len, x1_len))
+# else:
+# if isinstance(self.x1, And):
+# tmp_mtxs = self.x1.dr_wrt(wrt, want_stacks=True, reverse_mode=reverse_mode)
+# for mtx in tmp_mtxs:
+# mtxs.append(mtx)
+# else:
+# mtxs.append(self.x1.dr_wrt(wrt, reverse_mode=reverse_mode))
+# if mtxs[-1] is None:
+# mtxs[-1] = sp.csc_matrix((x1_len, input_len))
+#
+# if wrt is self.x2:
+# mtxs.append(sp.spdiags(np.ones(x2_len), 0, x2_len, x2_len))
+# else:
+# if isinstance(self.x2, And):
+# tmp_mtxs = self.x2.dr_wrt(wrt, want_stacks=True, reverse_mode=reverse_mode)
+# for mtx in tmp_mtxs:
+# mtxs.append(mtx)
+# else:
+# mtxs.append(self.x2.dr_wrt(wrt, reverse_mode=reverse_mode))
+# if mtxs[-1] is None:
+# mtxs[-1] = sp.csc_matrix((x2_len, input_len))
+#
+# if want_stacks:
+# return mtxs
+# else:
+# if any([sp.issparse(mtx) for mtx in mtxs]):
+# result = sp.vstack(mtxs, format='csc')
+# else:
+# result = np.vstack(mtxs)
+#
+# return result
+
+class Select(Permute):
+ terms = ['idxs', 'preferred_shape']
+ dterms = ['a']
+ term_order = 'a', 'idxs', 'preferred_shape'
+
+ def compute_r(self):
+ result = self.a.r.ravel()[self.idxs].copy()
+ if hasattr(self, 'preferred_shape'):
+ return result.reshape(self.preferred_shape)
+ else:
+ return result
+
+ def compute_dr_wrt(self, obj):
+ if obj is self.a:
+ if not hasattr(self, '_dr_cached'):
+ IS = np.arange(len(self.idxs))
+ JS = self.idxs.ravel()
+ ij = np.vstack((row(IS), row(JS)))
+ data = np.ones(len(self.idxs))
+ self._dr_cached = sp.csc_matrix((data, ij), shape=(len(self.idxs), np.prod(self.a.shape)))
+ return self._dr_cached
+
+ def on_changed(self, which):
+ if hasattr(self, '_dr_cached'):
+ if 'idxs' in which or self.a.r.size != self._dr_cached.shape[1]:
+ del self._dr_cached
+
+
+
+class AtleastNd(Ch):
+ dterms = 'x'
+ terms = 'ndims'
+
+ def compute_r(self):
+ xr = self.x.r
+ if self.ndims == 1:
+ target_shape = np.atleast_1d(xr).shape
+ elif self.ndims == 2:
+ target_shape = np.atleast_2d(xr).shape
+ elif self.ndims == 3:
+ target_shape = np.atleast_3d(xr).shape
+ else:
+ raise Exception('Need ndims to be 1, 2, or 3.')
+
+ return xr.reshape(target_shape)
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.x:
+ return 1
+
+def atleast_nd(ndims, *arys):
+ arys = [AtleastNd(x=ary, ndims=ndims) for ary in arys]
+ return arys if len(arys) > 1 else arys[0]
+
+def atleast_1d(*arys):
+ return atleast_nd(1, *arys)
+
+def atleast_2d(*arys):
+ return atleast_nd(2, *arys)
+
+def atleast_3d(*arys):
+ return atleast_nd(3, *arys)
+
+def squeeze(a, axis=None):
+ if isinstance(a, np.ndarray):
+ return np.squeeze(a, axis)
+ shape = np.squeeze(a.r, axis).shape
+ return a.reshape(shape)
+
+def expand_dims(a, axis):
+ if isinstance(a, np.ndarray):
+ return np.expand_dims(a, axis)
+ shape = np.expand_dims(a.r, axis).shape
+ return a.reshape(shape)
+
+def fliplr(m):
+ return m[:,::-1]
+
+def flipud(m):
+ return m[::-1,...]
+
+class Concatenate(Ch):
+
+ def on_changed(self, which):
+ if not hasattr(self, 'dr_cached'):
+ self.dr_cached = weakref.WeakKeyDictionary()
+
+ @property
+ def our_terms(self):
+ if not hasattr(self, '_our_terms'):
+ self._our_terms = [getattr(self, s) for s in self.dterms]
+ return self._our_terms
+
+ def __getstate__(self):
+ # Have to get rid of WeakKeyDictionaries for serialization
+ if hasattr(self, 'dr_cached'):
+ del self.dr_cached
+ return super(self.__class__, self).__getstate__()
+
+ def compute_r(self):
+ return np.concatenate([t.r for t in self.our_terms], axis=self.axis)
+
+ @property
+ def everything(self):
+ if not hasattr(self, '_everything'):
+ self._everything = np.arange(self.r.size).reshape(self.r.shape)
+ self._everything = np.swapaxes(self._everything, self.axis, 0)
+ return self._everything
+
+ def compute_dr_wrt(self, wrt):
+ if not hasattr(self, 'dr_cached'):
+ self.dr_cached = weakref.WeakKeyDictionary()
+ if wrt in self.dr_cached and self.dr_cached[wrt] is not None:
+ return self.dr_cached[wrt]
+
+ if wrt not in self.our_terms:
+ return
+
+ _JS = np.arange(wrt.size)
+ _data = np.ones(wrt.size)
+
+ IS = []
+ JS = []
+ data = []
+
+ offset = 0
+ for term in self.our_terms:
+ tsz = term.shape[self.axis]
+ if term is wrt:
+ JS += [_JS]
+ data += [_data]
+ IS += [np.swapaxes(self.everything[offset:offset+tsz], self.axis, 0).ravel()]
+ offset += tsz
+ IS = np.concatenate(IS).ravel()
+ JS = np.concatenate(JS).ravel()
+ data = np.concatenate(data)
+
+ res = sp.csc_matrix((data, (IS, JS)), shape=(self.r.size, wrt.size))
+
+ if len(list(self._parents.keys())) != 1:
+ self.dr_cached[wrt] = res
+ else:
+ self.dr_cached[wrt] = None
+
+ return res
+
+
+def expand_concatenates(mtxs, axis=0):
+ mtxs = list(mtxs)
+ done = []
+ while len(mtxs) > 0:
+ mtx = mtxs.pop(0)
+ if isinstance(mtx, Concatenate) and mtx.axis == axis:
+ mtxs = [getattr(mtx, s) for s in mtx.dterms] + mtxs
+ else:
+ done.append(mtx)
+ return done
+
+
+def concatenate(mtxs, axis=0, **kwargs):
+
+ mtxs = expand_concatenates(mtxs, axis)
+
+ result = Concatenate(**kwargs)
+ result.dterms = []
+ for i, mtx in enumerate(mtxs):
+ result.dterms.append('m%d' % (i,))
+ setattr(result, result.dterms[-1], mtx)
+ result.axis = axis
+ return result
+
+def hstack(mtxs, **kwargs):
+ return concatenate(mtxs, axis=1, **kwargs)
+
+def vstack(mtxs, **kwargs):
+ return concatenate([atleast_2d(m) for m in mtxs], axis=0, **kwargs)
+
+def dstack(mtxs, **kwargs):
+ return concatenate([atleast_3d(m) for m in mtxs], axis=2, **kwargs)
diff --git a/chumpy-0.70/chumpy/test_ch.py b/chumpy-0.70/chumpy/test_ch.py
new file mode 100755
index 00000000..0b2ad108
--- /dev/null
+++ b/chumpy-0.70/chumpy/test_ch.py
@@ -0,0 +1,621 @@
+#!/usr/bin/env python
+# encoding: utf-8
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+import time
+
+import unittest
+import numpy as np
+import scipy.sparse as sp
+
+from . import ch
+
+class TestCh(unittest.TestCase):
+
+
+ def test_cachehits(self):
+ """Test how many nodes are visited when cache is cleared.
+ If the number of hits changes, it has to be carefully
+ looked at to make sure that correctness and performance
+ don't get messed up by a change."""
+
+ a = ch.array(1)
+ b = ch.array(2)
+ c = a
+ for i in range(10):
+ c = a + c + b
+
+ c.dr_wrt(a)
+ c.dr_wrt(b)
+ self.assertEqual(a.clear_cache() + b.clear_cache(), 59)
+ c.dr_wrt(a)
+ c.dr_wrt(b)
+ self.assertEqual(a.clear_cache(123) + b.clear_cache(123), 41)
+
+ def test_nested_concatenate(self):
+ aa = ch.arange(3)
+ bb = ch.arange(4)
+ cc = ch.arange(5)
+
+ result = ch.concatenate((ch.concatenate((aa,bb)),cc))
+ self.assertTrue(result.m0 is aa)
+ self.assertTrue(result.m1 is bb)
+ self.assertTrue(result.m2 is cc)
+
+ self.assertTrue(result.dr_wrt(aa).nnz > 0)
+ self.assertTrue(result.dr_wrt(bb).nnz > 0)
+ self.assertTrue(result.dr_wrt(cc).nnz > 0)
+
+ def test_nandivide(self):
+ foo = ch.array(np.random.randn(16).reshape((4,4)))
+ bar = ch.array(np.random.randn(16).reshape((4,4)))
+ bar[2,2] = 0
+ self.assertEqual(ch.NanDivide(foo,bar)[2,2].r, 0.)
+ foo[2,2] = 0
+ self.assertEqual(ch.NanDivide(foo,bar)[2,2].r, 0.)
+
+ def test_casting(self):
+ for fn in float, int:
+ self.assertEqual(fn(np.array(5)), fn(ch.array(5)))
+ self.assertEqual(fn(np.array([[5]])), fn(ch.array([[5]])))
+
+ def test_tensordot(self):
+ an = np.arange(60.).reshape(3,4,5)
+ bn = np.arange(24.).reshape(4,3,2)
+ cn = np.tensordot(an,bn, axes=([1,0],[0,1]))
+
+ ac = ch.arange(60.).reshape(3,4,5)
+ bc = ch.arange(24.).reshape(4,3,2)
+ cc = ch.tensordot(ac,bc, axes=([1,0],[0,1]))
+
+ cc.r
+ cc.dr_wrt(ac)
+ cc.dr_wrt(bc)
+ #print cn
+
+ def test_make_sure_is_double(self):
+ x = ch.array([0])
+ self.assertTrue(isinstance(x.r[0], np.float64))
+
+ def test_cross(self):
+ aa = ch.random.randn(30).reshape((10,3))
+ bb = ch.random.randn(30).reshape((10,3))
+
+ cross_ch = ch.cross(aa, bb)
+ cross_np = np.cross(aa.r, bb.r)
+
+ # print cross_ch.r
+ # print cross_np
+
+ eps = 1.0
+ step = (np.random.rand(30) - .5).reshape((10,3)) * eps
+
+ gt_diff = np.cross(aa.r, bb.r+step) - cross_np
+ pr_diff = cross_ch.dr_wrt(bb).dot(step.ravel())
+ # print gt_diff
+ # print pr_diff
+ # print np.max(np.abs(gt_diff.ravel()-pr_diff.ravel()))
+ self.assertTrue(1e-14 > np.max(np.abs(gt_diff.ravel()-pr_diff.ravel())))
+
+ gt_diff = np.cross(aa.r+step, bb.r) - cross_np
+ pr_diff = cross_ch.dr_wrt(aa).dot(step.ravel())
+ #print gt_diff
+ # print pr_diff
+ # print np.max(np.abs(gt_diff.ravel()-pr_diff.ravel()))
+ self.assertTrue(1e-14 > np.max(np.abs(gt_diff.ravel()-pr_diff.ravel())))
+
+ def test_dr_wrt_selection(self):
+ aa = ch.arange(10,20)
+ bb = ch.arange(1,11)
+ cc = aa * bb + aa + bb +2
+
+ dr0 = cc.dr_wrt(aa[4:6])
+ dr1 = cc.dr_wrt(aa)[:,4:6]
+ self.assertTrue((dr0 - dr1).nnz == 0)
+
+ dr0 = cc.dr_wrt(bb[5:8])
+ dr1 = cc.dr_wrt(bb)[:,5:8]
+ self.assertTrue((dr0 - dr1).nnz == 0)
+
+
+ def test_sum_mean_std_var(self):
+ for fn in [ch.sum, ch.mean, ch.var, ch.std]:
+
+ # Create fake input and differences in input space
+ data1 = ch.ones((3,4,7,2))
+ data2 = ch.array(data1.r + .1 * np.random.rand(data1.size).reshape(data1.shape))
+ diff = data2.r - data1.r
+
+ # Compute outputs
+ result1 = fn(data1, axis=2)
+ result2 = fn(data2, axis=2)
+
+ # Empirical and predicted derivatives
+ gt = result2.r - result1.r
+ pred = result1.dr_wrt(data1).dot(diff.ravel()).reshape(gt.shape)
+
+ #print np.max(np.abs(gt - pred))
+
+ if fn in [ch.std, ch.var]:
+ self.assertTrue(1e-2 > np.max(np.abs(gt - pred)))
+ else:
+ self.assertTrue(1e-14 > np.max(np.abs(gt - pred)))
+ # test caching
+ dr0 = result1.dr_wrt(data1)
+ data1[:] = np.random.randn(data1.size).reshape(data1.shape)
+ self.assertTrue(result1.dr_wrt(data1) is dr0) # changing values shouldn't force recompute
+ result1.axis=1
+ self.assertTrue(result1.dr_wrt(data1) is not dr0)
+
+ self.assertEqual(ch.mean(ch.eye(3),axis=1).ndim, np.mean(np.eye(3),axis=1).ndim)
+ self.assertEqual(ch.mean(ch.eye(3),axis=0).ndim, np.mean(np.eye(3),axis=0).ndim)
+ self.assertEqual(ch.sum(ch.eye(3),axis=1).ndim, np.sum(np.eye(3),axis=1).ndim)
+ self.assertEqual(ch.sum(ch.eye(3),axis=0).ndim, np.sum(np.eye(3),axis=0).ndim)
+
+
+
+ def test_cumsum(self):
+ a = ch.array([1.,5.,3.,7.])
+ cs = ch.cumsum(a)
+ r1 = cs.r
+ dr = cs.dr_wrt(a)
+ diff = (ch.random.rand(4)-.5)*.1
+ a.x += diff.r
+ pred = dr.dot(diff.r)
+ gt = cs.r - r1
+ self.assertTrue(1e-13 > np.max(np.abs(gt - pred)))
+
+
+ def test_iteration_cache(self):
+ """ Each time you set an attribute, the cache (of r's and dr's) of
+ ancestors is cleared. Because children share ancestors, this means
+ these can be cleared multiple times unnecessarily; in some cases,
+ where lots of objects exist, this cache clearing can actually be a bottleneck.
+
+ Therefore, the concept of an iteration was added; intended to be used in
+ an optimization setting (see optimization.py) and in the set() method, it
+ avoids such redundant clearing of cache."""
+
+ a, b, c = ch.Ch(1), ch.Ch(2), ch.Ch(3)
+ x = a+b
+ y = x+c
+ self.assertTrue(y.r[0]==6)
+
+ a.__setattr__('x', 10, 1)
+ self.assertTrue(y.r == 15)
+ a.__setattr__('x', 100, 1)
+ self.assertTrue(y.r == 15)
+ a.__setattr__('x', 100, 2)
+ self.assertTrue(y.r == 105)
+
+ a, b, c = ch.array([1]), ch.array([2]), ch.array([3])
+ x = a+b
+ y = x+c
+ self.assertTrue(y.r[0]==6)
+
+ a.__setattr__('x', np.array([10]), 1)
+ self.assertTrue(y.r[0] == 15)
+ a.__setattr__('x', np.array(100), 1)
+ self.assertTrue(y.r[0] == 15)
+ a.__setattr__('x', np.array(100), 2)
+ self.assertTrue(y.r[0] == 105)
+ a.__setitem__(list(range(0,1)), np.array(200), 2)
+ self.assertTrue(y.r[0] == 105)
+ a.__setitem__(list(range(0,1)), np.array(200), 3)
+ self.assertTrue(y.r[0] == 205)
+
+
+
+ def test_stacking(self):
+
+ a1 = ch.Ch(np.arange(10).reshape(2,5))
+ b1 = ch.Ch(np.arange(20).reshape(4,5))
+ c1 = ch.vstack((a1,b1))
+ c1_check = np.vstack((a1.r, b1.r))
+ residuals1 = (c1_check - c1.r).ravel()
+
+
+ a2 = ch.Ch(np.arange(10).reshape(5,2))
+ b2 = ch.Ch(np.arange(20).reshape(5,4))
+ c2 = ch.hstack((a2,b2))
+ c2_check = np.hstack((a2.r, b2.r))
+ residuals2 = (c2_check - c2.r).ravel()
+
+ self.assertFalse(np.any(residuals1))
+ self.assertFalse(np.any(residuals2))
+
+ d0 = ch.array(np.arange(60).reshape((10,6)))
+ d1 = ch.vstack((d0[:4], d0[4:]))
+ d2 = ch.hstack((d1[:,:3], d1[:,3:]))
+ tmp = d2.dr_wrt(d0).todense()
+ diff = tmp - np.eye(tmp.shape[0])
+ self.assertFalse(np.any(diff.ravel()))
+
+
+
+ #def test_drs(self):
+ # a = ch.Ch(2)
+ # b = ch.Ch(3)
+ # c = a * b
+ # print c.dr_wrt(a)
+ # print c.compute_drs_wrt(a).r
+
+ @unittest.skip('We are using LinearOperator for this for now. Might change back though.')
+ def test_reorder_caching(self):
+ a = ch.Ch(np.zeros(8).reshape((4,2)))
+ b = a.T
+ dr0 = b.dr_wrt(a)
+ a.x = a.x + 1.
+ dr1 = b.dr_wrt(a)
+ self.assertTrue(dr0 is dr1)
+ a.x = np.zeros(4).reshape((2,2))
+ dr2 = b.dr_wrt(a)
+ self.assertTrue(dr2 is not dr1)
+
+ def test_transpose(self):
+ from .utils import row, col
+ from copy import deepcopy
+ for which in ('C', 'F'): # test in fortran and contiguous mode
+ a = ch.Ch(np.require(np.zeros(8).reshape((4,2)), requirements=which))
+ b = a.T
+
+ b1 = b.r.copy()
+ #dr = b.dr_wrt(a).copy()
+ dr = deepcopy(b.dr_wrt(a))
+
+ diff = np.arange(a.size).reshape(a.shape)
+ a.x = np.require(a.r + diff, requirements=which)
+ b2 = b.r.copy()
+
+ diff_pred = dr.dot(col(diff)).ravel()
+ diff_emp = (b2 - b1).ravel()
+ np.testing.assert_array_equal(diff_pred, diff_emp)
+
+
+ def test_unary(self):
+ fns = [ch.exp, ch.log, ch.sin, ch.arcsin, ch.cos, ch.arccos, ch.tan, ch.arctan, ch.negative, ch.square, ch.sqrt, ch.abs, ch.reciprocal]
+
+ eps = 1e-8
+ for f in fns:
+
+ x0 = ch.Ch(.25)
+ x1 = ch.Ch(x0.r+eps)
+
+ pred = f(x0).dr_wrt(x0)
+ empr = (f(x1).r - f(x0).r) / eps
+
+ # print pred
+ # print empr
+ if f is ch.reciprocal:
+ self.assertTrue(1e-6 > np.abs(pred.ravel()[0] - empr.ravel()[0]))
+ else:
+ self.assertTrue(1e-7 > np.abs(pred.ravel()[0] - empr.ravel()[0]))
+
+
+ def test_serialization(self):
+ # The main challenge with serialization is the "_parents"
+ # attribute, which is a nonserializable WeakKeyDictionary.
+ # So we pickle/unpickle, change a child and verify the value
+ # at root, and verify that both children have parentage.
+ from six.moves import cPickle as pickle
+ tmp = ch.Ch(10) + ch.Ch(20)
+ tmp = pickle.loads(pickle.dumps(tmp))
+ tmp.b.x = 30
+ self.assertTrue(tmp.r[0] == 40)
+ self.assertTrue(list(tmp.a._parents.keys())[0] == tmp)
+ self.assertTrue(list(tmp.a._parents.keys())[0] == list(tmp.b._parents.keys())[0])
+
+ def test_chlambda1(self):
+ c1, c2, c3 = ch.Ch(1), ch.Ch(2), ch.Ch(3)
+ adder = ch.ChLambda(lambda x, y: x+y)
+ adder.x = c1
+ adder.y = c2
+ self.assertTrue(adder.r == 3)
+ adder.x = c2
+ self.assertTrue(adder.r == 4)
+ adder.x = c1
+ self.assertTrue(adder.r == 3)
+
+
+ def test_chlambda2(self):
+ passthrough = ch.ChLambda( lambda x : x)
+ self.assertTrue(passthrough.dr_wrt(passthrough.x) is not None)
+ passthrough.x = ch.Ch(123)
+ self.assertTrue(passthrough.dr_wrt(passthrough.x) is not None)
+
+ # It's probably not reasonable to expect this
+ # to work for ChLambda
+ #def test_chlambda3(self):
+ # c1, c2, c3 = ch.Ch(1), ch.Ch(2), ch.Ch(3)
+ # triple = ch.ChLambda( lambda x, y, z : x(y, z))
+ # triple.x = Add
+ # triple.y = c2
+ # triple.z = c3
+
+
+
+
+
+ def test_amax(self):
+ from .ch import amax
+ import numpy as np
+ arr = np.empty((5,2,3,7))
+ arr.flat[:] = np.sin(np.arange(arr.size)*1000.)
+ #arr = np.array(np.sin(np.arange(24)*10000.).reshape(2,3,4))
+
+ for axis in range(len(arr.shape)):
+ a = amax(a=arr, axis=axis)
+ pred = a.dr_wrt(a.a).dot(arr.ravel())
+ real = np.amax(arr, axis=axis).ravel()
+ self.assertTrue(np.max(np.abs(pred-real)) < 1e-10)
+
+ def test_maximum(self):
+ from .utils import row, col
+ from .ch import maximum
+
+ # Make sure that when we compare the max of two *identical* numbers,
+ # we get the right derivatives wrt both
+ the_max = maximum(ch.Ch(1), ch.Ch(1))
+ self.assertTrue(the_max.r.ravel()[0] == 1.)
+ self.assertTrue(the_max.dr_wrt(the_max.a)[0,0] == 1.)
+ self.assertTrue(the_max.dr_wrt(the_max.b)[0,0] == 1.)
+
+ # Now test given that all numbers are different, by allocating from
+ # a pool of randomly permuted numbers.
+ # We test combinations of scalars and 2d arrays.
+ rnd = np.asarray(np.random.permutation(np.arange(20)), np.float64)
+ c1 = ch.Ch(rnd[:6].reshape((2,3)))
+ c2 = ch.Ch(rnd[6:12].reshape((2,3)))
+ s1 = ch.Ch(rnd[12])
+ s2 = ch.Ch(rnd[13])
+
+ eps = .1
+ for first in [c1, s1]:
+ for second in [c2, s2]:
+ the_max = maximum(first, second)
+
+ for which_to_change in [first, second]:
+
+
+ max_r0 = the_max.r.copy()
+ max_r_diff = np.max(np.abs(max_r0 - np.maximum(first.r, second.r)))
+ self.assertTrue(max_r_diff == 0)
+ max_dr = the_max.dr_wrt(which_to_change).copy()
+ which_to_change.x = which_to_change.x + eps
+ max_r1 = the_max.r.copy()
+
+ emp_diff = (the_max.r - max_r0).ravel()
+ pred_diff = max_dr.dot(col(eps*np.ones(max_dr.shape[1]))).ravel()
+
+ #print 'comparing the following numbers/vectors:'
+ #print first.r
+ #print second.r
+ #print 'empirical vs predicted difference:'
+ #print emp_diff
+ #print pred_diff
+ #print '-----'
+
+ max_dr_diff = np.max(np.abs(emp_diff-pred_diff))
+ #print 'max dr diff: %.2e' % (max_dr_diff,)
+ self.assertTrue(max_dr_diff < 1e-14)
+
+
+ def test_shared(self):
+
+ chs = [ch.Ch(i) for i in range(10)]
+ vrs = [float(i) for i in range(10)]
+
+ func = lambda a : a[0]*a[1] + (a[2]*a[3])/a[4]
+
+ chained_result = func(chs).r
+ regular_result = func(vrs)
+
+ self.assertTrue(chained_result == regular_result)
+ #print chained_result
+ #print regular_result
+
+ chained_func = func(chs)
+ chained_func.replace(chs[0], ch.Ch(50))
+ vrs[0] = 50
+
+ chained_result = chained_func.r
+ regular_result = func(vrs)
+
+ self.assertTrue(chained_result == regular_result)
+ #print chained_result
+ #print regular_result
+
+
+ def test_matmatmult(self):
+ from .ch import dot
+ mtx1 = ch.Ch(np.arange(6).reshape((3,2)))
+ mtx2 = ch.Ch(np.arange(8).reshape((2,4))*10)
+
+ mtx3 = dot(mtx1, mtx2)
+ #print mtx1.r
+ #print mtx2.r
+ #print mtx3.r
+ #print mtx3.dr_wrt(mtx1).todense()
+ #print mtx3.dr_wrt(mtx2).todense()
+
+ for mtx in [mtx1, mtx2]:
+ oldval = mtx3.r.copy()
+ mtxd = mtx3.dr_wrt(mtx).copy()
+ mtx_diff = np.random.rand(mtx.r.size).reshape(mtx.r.shape)
+ mtx.x = mtx.r + mtx_diff
+ mtx_emp = mtx3.r - oldval
+ mtx_pred = mtxd.dot(mtx_diff.ravel()).reshape(mtx_emp.shape)
+
+ self.assertTrue(np.max(np.abs(mtx_emp - mtx_pred)) < 1e-11)
+
+
+ def test_ndim(self):
+ vs = [ch.Ch(np.random.randn(6).reshape(2,3)) for i in range(6)]
+ res = vs[0] + vs[1] - vs[2] * vs[3] / (vs[4] ** 2) ** vs[5]
+ self.assertTrue(res.shape[0]==2 and res.shape[1]==3)
+ res = (vs[0] + 1) + (vs[1] - 2) - (vs[2] * 3) * (vs[3] / 4) / (vs[4] ** 2) ** vs[5]
+ self.assertTrue(res.shape[0]==2 and res.shape[1]==3)
+ drs = [res.dr_wrt(v) for v in vs]
+
+
+ def test_indexing(self):
+ big = ch.Ch(np.arange(60).reshape((10,6)))
+ little = big[1:3, 3:6]
+ self.assertTrue(np.max(np.abs(little.r - np.array([[9,10,11],[15,16,17]]))) == 0)
+
+ little = big[5]
+ self.assertTrue(np.max(np.abs(little.r - np.arange(30, 36))) == 0)
+ self.assertTrue(np.max(np.abs(sp.coo_matrix(little.dr_wrt(big)).col - np.arange(30,36))) == 0)
+
+ little = big[2, 3]
+ self.assertTrue(little.r[0] == 15.0)
+
+ little = big[2, 3:5]
+ self.assertTrue(np.max(np.abs(little.r - np.array([15, 16]))) == 0.)
+ _ = little.dr_wrt(big)
+
+ # Tests assignment through reorderings
+ aa = ch.arange(4*4*4).reshape((4,4,4))[:3,:3,:3]
+ aa[0,1,2] = 100
+ self.assertTrue(aa[0,1,2].r[0] == 100)
+
+ # Tests assignment through reorderings (NaN's are a special case)
+ aa = ch.arange(9).reshape((3,3))
+ aa[1,1] = np.nan
+ self.assertTrue(np.isnan(aa.r[1,1]))
+ self.assertFalse(np.isnan(aa.r[0,0]))
+
+
+ def test_redundancy_removal(self):
+
+ for MT in [False, True]:
+ x1, x2 = ch.Ch(10), ch.Ch(20)
+ x1_plus_x2_1 = x1 + x2
+ x1_plus_x2_2 = x1 + x2
+ redundant_sum = (x1_plus_x2_1 + x1_plus_x2_2) * 2
+ redundant_sum.MT = MT
+
+ self.assertTrue(redundant_sum.a.a is not redundant_sum.a.b)
+ redundant_sum.remove_redundancy()
+ self.assertTrue(redundant_sum.a.a is redundant_sum.a.b)
+
+ def test_caching(self):
+
+ vals = [10, 20, 30, 40, 50]
+ f = lambda a, b, c, d, e : a + (b * c) - d ** e
+
+ # Set up our objects
+ Cs = [ch.Ch(v) for v in vals]
+ C_result = f(*Cs)
+
+ # Sometimes residuals should be cached
+ r1 = C_result.r
+ r2 = C_result.r
+ self.assertTrue(r1 is r2)
+
+ # Other times residuals need refreshing
+ Cs[0].set(x=5)
+ r3 = C_result.r
+ self.assertTrue(r3 is not r2)
+
+ # Sometimes derivatives should be cached
+ dr1 = C_result.dr_wrt(Cs[1])
+ dr2 = C_result.dr_wrt(Cs[1])
+ self.assertTrue(dr1 is dr2)
+
+ # Other times derivatives need refreshing
+ Cs[2].set(x=5)
+ dr3 = C_result.dr_wrt(Cs[1])
+ self.assertTrue(dr3 is not dr2)
+
+
+ def test_scalars(self):
+
+ try:
+ import theano.tensor as T
+ from theano import function
+ except:
+ return
+
+ # Set up variables and function
+ vals = [1, 2, 3, 4, 5]
+ f = lambda a, b, c, d, e : a + (b * c) - d ** e
+
+ # Set up our objects
+ Cs = [ch.Ch(v) for v in vals]
+ C_result = f(*Cs)
+
+ # Set up Theano's equivalents
+ Ts = T.dscalars('T1', 'T2', 'T3', 'T4', 'T5')
+ TF = f(*Ts)
+ T_result = function(Ts, TF)
+
+ # Make sure values and derivatives are equal
+ self.assertEqual(C_result.r, T_result(*vals))
+ for k in range(len(vals)):
+ theano_derivative = function(Ts, T.grad(TF, Ts[k]))(*vals)
+ #print C_result.dr_wrt(Cs[k])
+ our_derivative = C_result.dr_wrt(Cs[k])[0,0]
+ #print theano_derivative, our_derivative
+ self.assertEqual(theano_derivative, our_derivative)
+
+
+ def test_vectors(self):
+
+ try:
+ import theano.tensor as T
+ from theano import function
+ except:
+ return
+
+ for MT in [False, True]:
+
+ # Set up variables and function
+ vals = [np.random.randn(20) for i in range(5)]
+ f = lambda a, b, c, d, e : a + (b * c) - d ** e
+
+ # Set up our objects
+ Cs = [ch.Ch(v) for v in vals]
+ C_result = f(*Cs)
+ C_result.MT = MT
+
+ # Set up Theano equivalents
+ Ts = T.dvectors('T1', 'T2', 'T3', 'T4', 'T5')
+ TF = f(*Ts)
+ T_result = function(Ts, TF)
+
+ if False:
+ import theano.gradient
+ which = 1
+ theano_sse = (TF**2.).sum()
+ theano_grad = theano.gradient.grad(theano_sse, Ts[which])
+ theano_fn = function(Ts, theano_grad)
+ print(theano_fn(*vals))
+ C_result_grad = ch.SumOfSquares(C_result).dr_wrt(Cs[which])
+ print(C_result_grad)
+
+ # if True:
+ # aaa = np.linalg.solve(C_result_grad.T.dot(C_result_grad), C_result_grad.dot(np.zeros(C_result_grad.shape[1])))
+ # theano_hes = theano.R_obbb = theano.R_op()
+
+ import pdb; pdb.set_trace()
+
+ # Make sure values and derivatives are equal
+ np.testing.assert_array_equal(C_result.r, T_result(*vals))
+ for k in range(len(vals)):
+ theano_derivative = function(Ts, T.jacobian(TF, Ts[k]))(*vals)
+ our_derivative = np.array(C_result.dr_wrt(Cs[k]).todense())
+ #print theano_derivative, our_derivative
+
+ # Theano produces has more nans than we do during exponentiation.
+ # So we test only on entries where Theano is without NaN's
+ without_nans = np.nonzero(np.logical_not(np.isnan(theano_derivative.flatten())))[0]
+ np.testing.assert_array_equal(theano_derivative.flatten()[without_nans], our_derivative.flatten()[without_nans])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/chumpy-0.70/chumpy/test_inner_composition.py b/chumpy-0.70/chumpy/test_inner_composition.py
new file mode 100755
index 00000000..d14e4535
--- /dev/null
+++ b/chumpy-0.70/chumpy/test_inner_composition.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python
+# encoding: utf-8
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+import unittest
+from .ch import Ch, depends_on
+
+class TestInnerComposition(unittest.TestCase):
+
+ def test_ic(self):
+ child = Child(a=Ch(10))
+ parent = Parent(child=child, aliased=Ch(50))
+
+ junk = [parent.aliased_dependency for k in range(3)]
+ self.assertTrue(parent.dcount == 1)
+ self.assertTrue(parent.ocount == 0)
+ self.assertTrue(parent.rcount == 0)
+
+ junk = [parent.r for k in range(3)]
+ self.assertTrue(parent.dcount == 1)
+ self.assertTrue(parent.ocount == 1)
+ self.assertTrue(parent.rcount == 1)
+
+ parent.aliased = Ch(20)
+ junk = [parent.aliased_dependency for k in range(3)]
+ self.assertTrue(parent.dcount == 2)
+ self.assertTrue(parent.ocount == 1)
+ self.assertTrue(parent.rcount == 1)
+
+ junk = [parent.r for k in range(3)]
+ self.assertTrue(parent.dcount == 2)
+ self.assertTrue(parent.ocount == 2)
+ self.assertTrue(parent.rcount == 2)
+
+class Parent(Ch):
+ dterms = ('aliased', 'child')
+
+ def __init__(self, *args, **kwargs):
+ self.dcount = 0
+ self.ocount = 0
+ self.rcount = 0
+
+
+ def on_changed(self, which):
+ assert('aliased' in which and 'child' in which)
+ if 'aliased' in which:
+ self.ocount += 1
+
+ @depends_on('aliased')
+ def aliased_dependency(self):
+ self.dcount += 1
+
+ @property
+ def aliased(self):
+ return self.child.a
+
+ @aliased.setter
+ def aliased(self, val):
+ self.child.a = val
+
+ def compute_r(self):
+ self.rcount += 1
+ return 0
+
+ def compute_dr_wrt(self, wrt):
+ pass
+
+
+class Child(Ch):
+ dterms = ('a',)
+
+
+
+if __name__ == '__main__':
+ suite = unittest.TestLoader().loadTestsFromTestCase(TestInnerComposition)
+ unittest.TextTestRunner(verbosity=2).run(suite)
diff --git a/chumpy-0.70/chumpy/test_linalg.py b/chumpy-0.70/chumpy/test_linalg.py
new file mode 100755
index 00000000..ba0e9aa6
--- /dev/null
+++ b/chumpy-0.70/chumpy/test_linalg.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python
+# encoding: utf-8
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+import numpy as np
+import unittest
+
+from .ch import Ch
+
+
+
+
+class TestLinalg(unittest.TestCase):
+
+ def setUp(self):
+ np.random.seed(0)
+
+
+ def test_slogdet(self):
+ from . import ch
+ tmp = ch.random.randn(100).reshape((10,10))
+ # print 'chumpy version: ' + str(slogdet(tmp)[1].r)
+ # print 'old version:' + str(np.linalg.slogdet(tmp.r)[1])
+
+ eps = 1e-10
+ diff = np.random.rand(100) * eps
+ diff_reshaped = diff.reshape((10,10))
+ gt = np.linalg.slogdet(tmp.r+diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1]
+ pred = ch.linalg.slogdet(tmp)[1].dr_wrt(tmp).dot(diff)
+ #print gt
+ #print pred
+ diff = gt - pred
+
+ self.assertTrue(np.max(np.abs(diff)) < 1e-12)
+
+ sgn_gt = np.linalg.slogdet(tmp.r)[0]
+ sgn_pred = ch.linalg.slogdet(tmp)[0]
+
+ #print sgn_gt
+ #print sgn_pred
+ diff = sgn_gt - sgn_pred.r
+ self.assertTrue(np.max(np.abs(diff)) < 1e-12)
+
+
+ def test_lstsq(self):
+ from .linalg import lstsq
+
+ shapes = ([10, 3], [3, 10])
+
+ for shape in shapes:
+ for b2d in True, False:
+ A = (np.random.rand(np.prod(shape))-.5).reshape(shape)
+ if b2d:
+ b = np.random.randn(shape[0],2)
+ else:
+ b = np.random.randn(shape[0])
+
+ x1, residuals1, rank1, s1 = lstsq(A, b)
+ x2, residuals2, rank2, s2 = np.linalg.lstsq(A, b)
+
+ #print x1.r
+ #print x2
+ #print residuals1.r
+ #print residuals2
+ self.assertTrue(np.max(np.abs(x1.r-x2)) < 1e-14)
+ if len(residuals2) > 0:
+ self.assertTrue(np.max(np.abs(residuals1.r-residuals2)) < 1e-14)
+
+
+
+
+ def test_pinv(self):
+ from .linalg import Pinv
+
+ data = (np.random.rand(12)-.5).reshape((3, 4))
+ pc_tall = Pinv(data)
+ pc_wide = Pinv(data.T)
+
+ pn_tall = np.linalg.pinv(data)
+ pn_wide = np.linalg.pinv(data.T)
+
+ tall_correct = np.max(np.abs(pc_tall.r - pn_tall)) < 1e-12
+ wide_correct = np.max(np.abs(pc_wide.r - pn_wide)) < 1e-12
+ # if not tall_correct or not wide_correct:
+ # print tall_correct
+ # print wide_correct
+ # import pdb; pdb.set_trace()
+ self.assertTrue(tall_correct)
+ self.assertTrue(wide_correct)
+
+ return # FIXME. how to test derivs?
+
+ for pc in [pc_tall, pc_wide]:
+
+ self.chkd(pc, pc.mtx)
+ import pdb; pdb.set_trace()
+
+
+
+ def test_svd(self):
+ from .linalg import Svd
+ eps = 1e-3
+ idx = 10
+
+ data = np.sin(np.arange(300)*100+10).reshape((-1,3))
+ data[3,:] = data[3,:]*0+10
+ data[:,1] *= 2
+ data[:,2] *= 4
+ data = data.copy()
+ u,s,v = np.linalg.svd(data, full_matrices=False)
+ data = Ch(data)
+ data2 = data.r.copy()
+ data2.ravel()[idx] += eps
+ u2,s2,v2 = np.linalg.svd(data2, full_matrices=False)
+
+
+ svdu, svdd, svdv = Svd(x=data)
+
+ # test singular values
+ diff_emp = (s2-s) / eps
+ diff_pred = svdd.dr_wrt(data)[:,idx]
+ #print diff_emp
+ #print diff_pred
+ ratio = diff_emp / diff_pred
+ #print ratio
+ self.assertTrue(np.max(np.abs(ratio - 1.)) < 1e-4)
+
+ # test V
+ diff_emp = (v2 - v) / eps
+ diff_pred = svdv.dr_wrt(data)[:,idx].reshape(diff_emp.shape)
+ ratio = diff_emp / diff_pred
+ #print ratio
+ self.assertTrue(np.max(np.abs(ratio - 1.)) < 1e-2)
+
+ # test U
+ diff_emp = (u2 - u) / eps
+ diff_pred = svdu.dr_wrt(data)[:,idx].reshape(diff_emp.shape)
+ ratio = diff_emp / diff_pred
+ #print ratio
+ self.assertTrue(np.max(np.abs(ratio - 1.)) < 1e-2)
+
+
+ def test_det(self):
+ from .linalg import Det
+
+ mtx1 = Ch(np.sin(2**np.arange(9)).reshape((3,3)))
+ mtx1_det = Det(mtx1)
+ dr = mtx1_det.dr_wrt(mtx1)
+
+ eps = 1e-5
+ mtx2 = mtx1.r.copy()
+ input_diff = np.sin(np.arange(mtx2.size)).reshape(mtx2.shape) * eps
+ mtx2 += input_diff
+ mtx2_det = Det(mtx2)
+
+ output_diff_emp = (np.linalg.det(mtx2) - np.linalg.det(mtx1.r)).ravel()
+
+ output_diff_pred = Det(mtx1).dr_wrt(mtx1).dot(input_diff.ravel())
+
+ #print output_diff_emp
+ #print output_diff_pred
+
+ self.assertTrue(np.max(np.abs(output_diff_emp - output_diff_pred)) < eps*1e-4)
+ self.assertTrue(np.max(np.abs(mtx1_det.r - np.linalg.det(mtx1.r)).ravel()) == 0)
+
+
+
+ def test_inv1(self):
+ from .linalg import Inv
+
+ mtx1 = Ch(np.sin(2**np.arange(9)).reshape((3,3)))
+ mtx1_inv = Inv(mtx1)
+ dr = mtx1_inv.dr_wrt(mtx1)
+
+ eps = 1e-5
+ mtx2 = mtx1.r.copy()
+ input_diff = np.sin(np.arange(mtx2.size)).reshape(mtx2.shape) * eps
+ mtx2 += input_diff
+ mtx2_inv = Inv(mtx2)
+
+ output_diff_emp = (np.linalg.inv(mtx2) - np.linalg.inv(mtx1.r)).ravel()
+ output_diff_pred = Inv(mtx1).dr_wrt(mtx1).dot(input_diff.ravel())
+
+ #print output_diff_emp
+ #print output_diff_pred
+
+ self.assertTrue(np.max(np.abs(output_diff_emp - output_diff_pred)) < eps*1e-4)
+ self.assertTrue(np.max(np.abs(mtx1_inv.r - np.linalg.inv(mtx1.r)).ravel()) == 0)
+
+ def test_inv2(self):
+ from .linalg import Inv
+
+ eps = 1e-8
+ idx = 13
+
+ mtx1 = np.random.rand(100).reshape((10,10))
+ mtx2 = mtx1.copy()
+ mtx2.ravel()[idx] += eps
+
+ diff_emp = (np.linalg.inv(mtx2) - np.linalg.inv(mtx1)) / eps
+
+ mtx1 = Ch(mtx1)
+ diff_pred = Inv(mtx1).dr_wrt(mtx1)[:,13].reshape(diff_emp.shape)
+ #print diff_emp
+ #print diff_pred
+ #print diff_emp - diff_pred
+ self.assertTrue(np.max(np.abs(diff_pred.ravel()-diff_emp.ravel())) < 1e-4)
+
+ @unittest.skipIf(np.__version__ < '1.8',
+ "broadcasting for matrix inverse not supported in numpy < 1.8")
+ def test_inv3(self):
+ """Test linalg.inv with broadcasting support."""
+
+ from .linalg import Inv
+
+ mtx1 = Ch(np.sin(2**np.arange(12)).reshape((3,2,2)))
+ mtx1_inv = Inv(mtx1)
+ dr = mtx1_inv.dr_wrt(mtx1)
+
+ eps = 1e-5
+ mtx2 = mtx1.r.copy()
+ input_diff = np.sin(np.arange(mtx2.size)).reshape(mtx2.shape) * eps
+ mtx2 += input_diff
+ mtx2_inv = Inv(mtx2)
+
+ output_diff_emp = (np.linalg.inv(mtx2) - np.linalg.inv(mtx1.r)).ravel()
+ output_diff_pred = Inv(mtx1).dr_wrt(mtx1).dot(input_diff.ravel())
+
+ # print output_diff_emp
+ # print output_diff_pred
+
+ self.assertTrue(np.max(np.abs(output_diff_emp.ravel() - output_diff_pred.ravel())) < eps*1e-3)
+ self.assertTrue(np.max(np.abs(mtx1_inv.r - np.linalg.inv(mtx1.r)).ravel()) == 0)
+
+ def chkd(self, obj, parm, eps=1e-14):
+ backed_up = parm.x
+
+ if True:
+ diff = (np.random.rand(parm.size)-.5).reshape(parm.shape)
+ else:
+ diff = np.zeros(parm.shape)
+ diff.ravel()[4] = 2.
+
+ dr = obj.dr_wrt(parm)
+
+ parm.x = backed_up - diff*eps
+ r_lower = obj.r
+
+ parm.x = backed_up + diff*eps
+ r_upper = obj.r
+
+ diff_emp = (r_upper - r_lower) / (eps*2.)
+ diff_pred = dr.dot(diff.ravel()).reshape(diff_emp.shape)
+
+ #print diff_emp
+ #print diff_pred
+ print(diff_emp / diff_pred)
+ print(diff_emp - diff_pred)
+
+ parm.x = backed_up
+
+
+
+suite = unittest.TestLoader().loadTestsFromTestCase(TestLinalg)
+
+if __name__ == '__main__':
+ unittest.main()
+
diff --git a/chumpy-0.70/chumpy/test_optimization.py b/chumpy-0.70/chumpy/test_optimization.py
new file mode 100755
index 00000000..ce3e0d03
--- /dev/null
+++ b/chumpy-0.70/chumpy/test_optimization.py
@@ -0,0 +1,204 @@
+#!/usr/bin/env python
+# encoding: utf-8
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+import time
+from numpy import *
+import unittest
+from . import ch
+from .optimization import minimize
+from .ch import Ch
+import numpy as np
+from scipy.optimize import rosen, rosen_der
+from .utils import row, col
+
+
+visualize = False
+
+
+def Rosen():
+
+ args = {
+ 'x1': Ch(-120.),
+ 'x2': Ch(-100.)
+ }
+ r1 = Ch(lambda x1, x2 : (x2 - x1**2.) * 10., args)
+ r2 = Ch(lambda x1 : x1 * -1. + 1, args)
+
+ func = [r1, r2]
+
+ return func, [args['x1'], args['x2']]
+
+class Madsen(Ch):
+ dterms = ('x',)
+ def compute_r(self):
+ x1 = self.x.r[0]
+ x2 = self.x.r[1]
+ result = np.array((
+ x1**2 + x2**2 + x1 * x2,
+ np.sin(x1),
+ np.cos(x2)
+ ))
+ return result
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is not self.x:
+ return None
+ jac = np.zeros((3,2))
+ x1 = self.x.r[0]
+ x2 = self.x.r[1]
+ jac[0,0] = 2. * x1 + x2
+ jac[0,1] = 2. * x2 + x1
+
+ jac[1,0] = np.cos(x1)
+ jac[1,1] = 0
+
+ jac[2,0] = 0
+ jac[2,1] = -np.sin(x2)
+
+ return jac
+
+
+ def set_and_get_r(self, x_in):
+ self.x = Ch(x_in)
+ return col(self.r)
+
+ def set_and_get_dr(self, x_in):
+ self.x = Ch(x_in)
+ return self.dr_wrt(self.x)
+
+
+
+
+class RosenCh(Ch):
+ dterms = ('x',)
+ def compute_r(self):
+
+ result = np.array((rosen(self.x.r) ))
+
+ return result
+
+ def set_and_get_r(self, x_in):
+ self.x = Ch(x_in)
+ return col(self.r)
+
+ def set_and_get_dr(self, x_in):
+ self.x = Ch(x_in)
+ return self.dr_wrt(self.x).flatten()
+
+
+ def compute_dr_wrt(self, wrt):
+ if wrt is self.x:
+ if visualize:
+ import matplotlib.pyplot as plt
+ residuals = np.sum(self.r**2)
+ print('------> RESIDUALS %.2e' % (residuals,))
+ print('------> CURRENT GUESS %s' % (str(self.x.r),))
+ plt.figure(123)
+
+ if not hasattr(self, 'vs'):
+ self.vs = []
+ self.xs = []
+ self.ys = []
+ self.vs.append(residuals)
+ self.xs.append(self.x.r[0])
+ self.ys.append(self.x.r[1])
+ plt.clf();
+ plt.subplot(1,2,1)
+ plt.plot(self.vs)
+ plt.subplot(1,2,2)
+ plt.plot(self.xs, self.ys)
+ plt.draw()
+
+
+ return row(rosen_der(self.x.r))
+
+
+
+class TestOptimization(unittest.TestCase):
+
+ def test_dogleg_rosen(self):
+ obj, freevars = Rosen()
+ minimize(fun=obj, x0=freevars, method='dogleg', options={'maxiter': 337, 'disp': False})
+ self.assertTrue(freevars[0].r[0]==1.)
+ self.assertTrue(freevars[1].r[0]==1.)
+
+ def test_dogleg_madsen(self):
+ obj = Madsen(x = Ch(np.array((3.,1.))))
+ minimize(fun=obj, x0=[obj.x], method='dogleg', options={'maxiter': 34, 'disp': False})
+ self.assertTrue(np.sum(obj.r**2)/2 < 0.386599528247)
+
+ @unittest.skip('negative sign in exponent screws with reverse mode')
+ def test_bfgs_rosen(self):
+ from .optimization import minimize_bfgs_lsq
+ obj, freevars = Rosen()
+ minimize_bfgs_lsq(obj=obj, niters=421, verbose=False, free_variables=freevars)
+ self.assertTrue(freevars[0].r[0]==1.)
+ self.assertTrue(freevars[1].r[0]==1.)
+
+ def test_bfgs_madsen(self):
+ from .ch import SumOfSquares
+ import scipy.optimize
+ obj = Ch(lambda x : SumOfSquares(Madsen(x = x)) )
+
+ def errfunc(x):
+ obj.x = Ch(x)
+ return obj.r
+
+ def gradfunc(x):
+ obj.x = Ch(x)
+ return obj.dr_wrt(obj.x).ravel()
+
+ x0 = np.array((3., 1.))
+
+ # Optimize with built-in bfgs.
+ # Note: with 8 iters, this actually requires 14 gradient evaluations.
+ # This can be verified by setting "disp" to 1.
+ #tm = time.time()
+ x1 = scipy.optimize.fmin_bfgs(errfunc, x0, fprime=gradfunc, maxiter=8, disp=0)
+ #print 'forward: took %.es' % (time.time() - tm,)
+ self.assertLess(obj.r/2., 0.4)
+
+ # Optimize with chumpy's minimize (which uses scipy's bfgs).
+ obj.x = x0
+ minimize(fun=obj, x0=[obj.x], method='bfgs', options={'maxiter': 8, 'disp': False})
+ self.assertLess(obj.r/2., 0.4)
+
+ def test_nested_select(self):
+ def beales(x, y):
+ e1 = 1.5 - x + x*y
+ e2 = 2.25 - x + x*(y**2)
+ e3 = 2.625 - x + x*(y**3)
+ return {'e1': e1, 'e2': e2, 'e3': e3}
+
+ x1 = ch.zeros(10)
+ y1 = ch.zeros(10)
+
+ # With a single select this worked
+ minimize(beales(x1, y1), x0=[x1[1:4], y1], method='dogleg', options={'disp': False})
+
+ x2 = ch.zeros(10)
+ y2 = ch.zeros(10)
+
+ # But this used to raise `AttributeError: 'Select' object has no attribute 'x'`
+ minimize(beales(x2, y2), x0=[x2[1:8][:3], y2], method='dogleg', options={'disp': False})
+ np.testing.assert_array_equal(x1, x2)
+ np.testing.assert_array_equal(y1, y2)
+
+
+suite = unittest.TestLoader().loadTestsFromTestCase(TestOptimization)
+
+if __name__ == '__main__':
+
+ if False: # show rosen
+ import matplotlib.pyplot as plt
+ visualize = True
+ plt.ion()
+ unittest.main()
+ import pdb; pdb.set_trace()
+ else:
+ unittest.main()
diff --git a/chumpy-0.70/chumpy/testing.py b/chumpy-0.70/chumpy/testing.py
new file mode 100644
index 00000000..e8bed9ef
--- /dev/null
+++ b/chumpy-0.70/chumpy/testing.py
@@ -0,0 +1,21 @@
+from . import ch
+import numpy as np
+
+fn1 = 'assert_allclose', 'assert_almost_equal', 'assert_approx_equal', 'assert_array_almost_equal', 'assert_array_almost_equal_nulp', 'assert_array_equal', 'assert_array_less', 'assert_array_max_ulp', 'assert_equal', 'assert_no_warnings', 'assert_string_equal'
+fn2 = 'assert_raises', 'assert_warns'
+
+# These are unhandled
+fn3 = 'build_err_msg', 'dec', 'decorate_methods', 'decorators', 'division', 'importall', 'jiffies', 'measure', 'memusage', 'nosetester', 'numpytest', 'print_assert_equal', 'print_function', 'raises', 'rand', 'run_module_suite', 'rundocs', 'runstring', 'test', 'utils', 'verbose'
+
+__all__ = fn1 + fn2
+
+for rtn in fn1:
+ exec('def %s(*args, **kwargs) : return np.testing.%s(np.asarray(args[0]), np.asarray(args[1]), *args[2:], **kwargs)' % (rtn, rtn))
+
+for rtn in fn2:
+ exec('def %s(*args, **kwargs) : return np.testing.%s(*args, **kwargs)' % (rtn, rtn))
+
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/chumpy-0.70/chumpy/utils.py b/chumpy-0.70/chumpy/utils.py
new file mode 100644
index 00000000..19287d3c
--- /dev/null
+++ b/chumpy-0.70/chumpy/utils.py
@@ -0,0 +1,93 @@
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+import scipy.sparse as sp
+import numpy as np
+
+def row(A):
+ return A.reshape((1, -1))
+
+
+def col(A):
+ return A.reshape((-1, 1))
+
+class timer(object):
+ def time(self):
+ import time
+ return time.time()
+ def __init__(self):
+ self._elapsed = 0
+ self._start = self.time()
+ def __call__(self):
+ if self._start is not None:
+ return self._elapsed + self.time() - self._start
+ else:
+ return self._elapsed
+ def pause(self):
+ assert self._start is not None
+ self._elapsed += self.time() - self._start
+ self._start = None
+ def resume(self):
+ assert self._start is None
+ self._start = self.time()
+
+def dfs_do_func_on_graph(node, func, *args, **kwargs):
+ '''
+ invoke func on each node of the dr graph
+ '''
+ for _node in node.tree_iterator():
+ func(_node, *args, **kwargs)
+
+
+def sparse_is_desireable(lhs, rhs):
+ '''
+ Examines a pair of matrices and determines if the result of their multiplication should be sparse or not.
+ '''
+ return False
+ if len(lhs.shape) == 1:
+ return False
+ else:
+ lhs_rows, lhs_cols = lhs.shape
+
+ if len(rhs.shape) == 1:
+ rhs_rows = 1
+ rhs_cols = rhs.size
+ else:
+ rhs_rows, rhs_cols = rhs.shape
+
+ result_size = lhs_rows * rhs_cols
+
+ if sp.issparse(lhs) and sp.issparse(rhs):
+ return True
+ elif sp.issparse(lhs):
+ lhs_zero_rows = lhs_rows - np.unique(lhs.nonzero()[0]).size
+ rhs_zero_cols = np.all(rhs==0, axis=0).sum()
+
+ elif sp.issparse(rhs):
+ lhs_zero_rows = np.all(lhs==0, axis=1).sum()
+ rhs_zero_cols = rhs_cols- np.unique(rhs.nonzero()[1]).size
+ else:
+ lhs_zero_rows = np.all(lhs==0, axis=1).sum()
+ rhs_zero_cols = np.all(rhs==0, axis=0).sum()
+
+ num_zeros = lhs_zero_rows * rhs_cols + rhs_zero_cols * lhs_rows - lhs_zero_rows * rhs_zero_cols
+
+ # A sparse matrix uses roughly 16 bytes per nonzero element (8 + 2 4-byte inds), while a dense matrix uses 8 bytes per element. So the break even point for sparsity is 50% nonzero. But in practice, it seems to be that the compression in a csc or csr matrix gets us break even at ~65% nonzero, which lets us say 50% is a conservative, worst cases cutoff.
+ return (float(num_zeros) / float(size)) >= 0.5
+
+
+def convert_inputs_to_sparse_if_necessary(lhs, rhs):
+ '''
+ This function checks to see if a sparse output is desireable given the inputs and if so, casts the inputs to sparse in order to make it so.
+ '''
+ if not sp.issparse(lhs) or not sp.issparse(rhs):
+ if sparse_is_desireable(lhs, rhs):
+ if not sp.issparse(lhs):
+ lhs = sp.csc_matrix(lhs)
+ #print "converting lhs into sparse matrix"
+ if not sp.issparse(rhs):
+ rhs = sp.csc_matrix(rhs)
+ #print "converting rhs into sparse matrix"
+ return lhs, rhs
diff --git a/chumpy-0.70/chumpy/version.py b/chumpy-0.70/chumpy/version.py
new file mode 100644
index 00000000..19db31a3
--- /dev/null
+++ b/chumpy-0.70/chumpy/version.py
@@ -0,0 +1,3 @@
+version = '0.70'
+short_version = version
+full_version = version
diff --git a/chumpy-0.70/requirements.txt b/chumpy-0.70/requirements.txt
new file mode 100644
index 00000000..dd4a9b71
--- /dev/null
+++ b/chumpy-0.70/requirements.txt
@@ -0,0 +1,3 @@
+numpy>=1.8.1
+scipy>=0.13.0
+six>=1.11.0
diff --git a/chumpy-0.70/setup.py b/chumpy-0.70/setup.py
new file mode 100644
index 00000000..89ded38b
--- /dev/null
+++ b/chumpy-0.70/setup.py
@@ -0,0 +1,59 @@
+"""
+Author(s): Matthew Loper
+
+See LICENCE.txt for licensing and contact information.
+"""
+
+from distutils.core import setup
+try: # for pip >= 10
+ from pip._internal.req import parse_requirements
+except ImportError: # for pip <= 9.0.3
+ from pip.req import parse_requirements
+from runpy import run_path
+
+install_reqs = parse_requirements('requirements.txt', session=False)
+try: # for pip < 20.1
+ install_requires = [str(ir.req) for ir in install_reqs]
+except AttributeError: # for pip >= 20.1
+ install_requires = [str(ir.requirement) for ir in install_reqs]
+
+def get_version():
+ namespace = run_path('chumpy/version.py')
+ return namespace['version']
+
+setup(name='chumpy',
+ version=get_version(),
+ packages = ['chumpy'],
+ author='Matthew Loper',
+ author_email='matt.loper@gmail.com',
+ url='https://github.com/mattloper/chumpy',
+ description='chumpy',
+ license='MIT',
+ install_requires=install_requires,
+
+ # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
+ classifiers=[
+ # How mature is this project? Common values are
+ # 3 - Alpha
+ # 4 - Beta
+ # 5 - Production/Stable
+ 'Development Status :: 4 - Beta',
+
+ # Indicate who your project is intended for
+ 'Intended Audience :: Science/Research',
+ 'Topic :: Scientific/Engineering :: Mathematics',
+
+ # Pick your license as you wish (should match "license" above)
+ 'License :: OSI Approved :: MIT License',
+
+ # Specify the Python versions you support here. In particular, ensure
+ # that you indicate whether you support Python 2, Python 3 or both.
+ 'Programming Language :: Python :: 2',
+ 'Programming Language :: Python :: 2.7',
+ 'Programming Language :: Python :: 3',
+
+ 'Operating System :: MacOS :: MacOS X',
+ 'Operating System :: POSIX :: Linux'
+ ],
+)
+
diff --git a/download_weights.sh b/download_weights.sh
index affa5a9d..961b6a48 100644
--- a/download_weights.sh
+++ b/download_weights.sh
@@ -6,19 +6,17 @@ CheckpointsDir="models"
# Create necessary directories
mkdir -p models/musetalk models/musetalkV15 models/syncnet models/dwpose models/face-parse-bisent models/sd-vae models/whisper
-# Install required packages
-pip install -U "huggingface_hub[cli]"
-pip install gdown
-
# Set HuggingFace mirror endpoint
export HF_ENDPOINT=https://hf-mirror.com
+echo "🚀 Starting download using huggingface-cli..."
+
# Download MuseTalk V1.0 weights
huggingface-cli download TMElyralab/MuseTalk \
--local-dir $CheckpointsDir \
--include "musetalk/musetalk.json" "musetalk/pytorch_model.bin"
-# Download MuseTalk V1.5 weights (unet.pth)
+# Download MuseTalk V1.5 weights
huggingface-cli download TMElyralab/MuseTalk \
--local-dir $CheckpointsDir \
--include "musetalkV15/musetalk.json" "musetalkV15/unet.pth"
@@ -43,9 +41,12 @@ huggingface-cli download ByteDance/LatentSync \
--local-dir $CheckpointsDir/syncnet \
--include "latentsync_syncnet.pt"
-# Download Face Parse Bisent weights
-gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O $CheckpointsDir/face-parse-bisent/79999_iter.pth
+echo "🚀 Downloading supplemental weights via gdown and curl..."
+
+# Fix gdown usage (ID directly or use URL)
+gdown 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O $CheckpointsDir/face-parse-bisent/79999_iter.pth
+
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth \
-o $CheckpointsDir/face-parse-bisent/resnet18-5c106cde.pth
-echo "✅ All weights have been downloaded successfully!"
+echo "✅ All weights have been downloaded successfully!"
diff --git a/env-5070.md b/env-5070.md
new file mode 100644
index 00000000..d5554775
--- /dev/null
+++ b/env-5070.md
@@ -0,0 +1,52 @@
+lsb_release -a
+No LSB modules are available.
+Distributor ID: Ubuntu
+Description: Ubuntu 24.04.3 LTS
+Release: 24.04
+Codename: noble
+
+nvcc --version
+nvcc: NVIDIA (R) Cuda compiler driver
+Copyright (c) 2005-2025 NVIDIA Corporation
+Built on Fri_Feb_21_20:23:50_PST_2025
+Cuda compilation tools, release 12.8, V12.8.93
+Build cuda_12.8.r12.8/compiler.35583870_0
+nvidia-smi
+Wed May 6 11:19:02 2026
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. |
+|=========================================+========================+======================|
+| 0 NVIDIA GeForce RTX 5070 Off | 00000000:01:00.0 Off | N/A |
+| 0% 33C P8 3W / 250W | 309MiB / 12227MiB | 0% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes: |
+| GPU GI CI PID Type Process name GPU Memory |
+| ID ID Usage |
+|=========================================================================================|
+| 0 N/A N/A 3021 G /usr/lib/xorg/Xorg 74MiB |
+| 0 N/A N/A 3174 C+G ...c/gnome-remote-desktop-daemon 162MiB |
+| 0 N/A N/A 3237 G /usr/bin/gnome-shell 7MiB |
+| 0 N/A N/A 4028 G /usr/bin/gnome-control-center 9MiB |
++-----------------------------------------------------------------------------------------+
+
+ffmpeg -version
+ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
+built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
+configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --enable-openal --enable-opencl --enable-opengl --disable-sndio --enable-libvpl --disable-libmfx --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-chromaprint --enable-frei0r --enable-ladspa --enable-libbluray --enable-libjack --enable-libpulse --enable-librabbitmq --enable-librist --enable-libsrt --enable-libssh --enable-libsvtav1 --enable-libx264 --enable-libzmq --enable-libzvbi --enable-lv2 --enable-sdl2 --enable-libplacebo --enable-librav1e --enable-pocketsphinx --enable-librsvg --enable-libjxl --enable-shared
+libavutil 58. 29.100 / 58. 29.100
+libavcodec 60. 31.102 / 60. 31.102
+libavformat 60. 16.100 / 60. 16.100
+libavdevice 60. 3.100 / 60. 3.100
+libavfilter 9. 12.100 / 9. 12.100
+libswscale 7. 5.100 / 7. 5.100
+libswresample 4. 12.100 / 4. 12.100
+libpostproc 57. 3.100 / 57. 3.100
+ldconfig -p | grep libsndfile
+libsndfile.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libsndfile.so.1
diff --git a/musetalk/__init__.py b/musetalk/__init__.py
new file mode 100644
index 00000000..7d34b07a
--- /dev/null
+++ b/musetalk/__init__.py
@@ -0,0 +1 @@
+# MuseTalk package
diff --git a/musetalk/loss/vgg_face.py b/musetalk/loss/vgg_face.py
index b41faadf..7915a08d 100755
--- a/musetalk/loss/vgg_face.py
+++ b/musetalk/loss/vgg_face.py
@@ -132,7 +132,7 @@ class Vgg19(torch.nn.Module):
"""
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
- vgg_pretrained_features = models.vgg19(pretrained=True).features
+ vgg_pretrained_features = models.vgg19(weights='DEFAULT').features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
diff --git a/musetalk/models/unet.py b/musetalk/models/unet.py
index 575e79af..7c2edb05 100755
--- a/musetalk/models/unet.py
+++ b/musetalk/models/unet.py
@@ -41,7 +41,7 @@ def __init__(self,
self.device = device
else:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
+ weights = torch.load(model_path, weights_only=False) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device, weights_only=False)
self.model.load_state_dict(weights)
if use_float16:
self.model = self.model.half()
diff --git a/musetalk/utils/audio_processor.py b/musetalk/utils/audio_processor.py
index dbdee253..cb845874 100755
--- a/musetalk/utils/audio_processor.py
+++ b/musetalk/utils/audio_processor.py
@@ -58,7 +58,7 @@ def get_whisper_chunk(
# Trim the last segment to remove padding
sr = 16000
audio_fps = 50
- fps = int(fps)
+ fps = float(fps)
whisper_idx_multiplier = audio_fps / fps
num_frames = math.floor((librosa_length / sr) * fps)
actual_length = math.floor((librosa_length / sr) * audio_fps)
diff --git a/musetalk/utils/face_detection/api.py b/musetalk/utils/face_detection/api.py
index 0a6a8d66..9c1573b8 100755
--- a/musetalk/utils/face_detection/api.py
+++ b/musetalk/utils/face_detection/api.py
@@ -63,7 +63,8 @@ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
# Get the face detector
- face_detector_module = __import__('face_detection.detection.' + face_detector,
+ package_name = __name__.rsplit('.', 1)[0]
+ face_detector_module = __import__(package_name + '.detection.' + face_detector,
globals(), locals(), [face_detector], 0)
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
diff --git a/musetalk/utils/face_detection/detection/sfd/sfd_detector.py b/musetalk/utils/face_detection/detection/sfd/sfd_detector.py
index 8fbce152..5a92c8f3 100755
--- a/musetalk/utils/face_detection/detection/sfd/sfd_detector.py
+++ b/musetalk/utils/face_detection/detection/sfd/sfd_detector.py
@@ -21,7 +21,7 @@ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path
if not os.path.isfile(path_to_detector):
model_weights = load_url(models_urls['s3fd'])
else:
- model_weights = torch.load(path_to_detector)
+ model_weights = torch.load(path_to_detector, weights_only=False)
self.face_detector = s3fd()
self.face_detector.load_state_dict(model_weights)
diff --git a/musetalk/utils/face_parsing/__init__.py b/musetalk/utils/face_parsing/__init__.py
index 09c1c02a..7801a766 100755
--- a/musetalk/utils/face_parsing/__init__.py
+++ b/musetalk/utils/face_parsing/__init__.py
@@ -62,9 +62,9 @@ def model_init(self,
net = BiSeNet(resnet_path)
if torch.cuda.is_available():
net.cuda()
- net.load_state_dict(torch.load(model_pth))
+ net.load_state_dict(torch.load(model_pth, weights_only=False))
else:
- net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
+ net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu'), weights_only=False))
net.eval()
return net
diff --git a/musetalk/utils/face_parsing/resnet.py b/musetalk/utils/face_parsing/resnet.py
index e2e5d87e..a306abb7 100755
--- a/musetalk/utils/face_parsing/resnet.py
+++ b/musetalk/utils/face_parsing/resnet.py
@@ -80,7 +80,7 @@ def forward(self, x):
return feat8, feat16, feat32
def init_weight(self, model_path):
- state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
+ state_dict = torch.load(model_path, weights_only=False) #modelzoo.load_url(resnet18_url)
self_state_dict = self.state_dict()
for k, v in state_dict.items():
if 'fc' in k: continue
diff --git a/musetalk/utils/preprocessing.py b/musetalk/utils/preprocessing.py
index 978480c6..e5d7ddea 100755
--- a/musetalk/utils/preprocessing.py
+++ b/musetalk/utils/preprocessing.py
@@ -1,5 +1,5 @@
import sys
-from face_detection import FaceAlignment,LandmarksType
+from .face_detection import FaceAlignment,LandmarksType
from os import listdir, path
import subprocess
import numpy as np
diff --git a/musetalk/utils/training_utils.py b/musetalk/utils/training_utils.py
index 010f01f1..e2d4f429 100644
--- a/musetalk/utils/training_utils.py
+++ b/musetalk/utils/training_utils.py
@@ -75,7 +75,7 @@ def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
if not cfg.random_init_unet:
pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
- checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
+ checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device, weights_only=False)
model_dict['unet'].load_state_dict(checkpoint)
unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
@@ -261,7 +261,8 @@ def initialize_syncnet(cfg, accelerator, weight_dtype):
print(
f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
checkpoint = torch.load(
- syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
+ os.path.join(vae_path, 'diffusion_pytorch_model.bin'), map_location=accelerator.device, weights_only=False
+ )
syncnet.load_state_dict(checkpoint["state_dict"])
syncnet.to(dtype=weight_dtype)
syncnet.requires_grad_(False)
diff --git a/musetalk/whisper/whisper/__init__.py b/musetalk/whisper/whisper/__init__.py
index b9255534..b47a6d38 100755
--- a/musetalk/whisper/whisper/__init__.py
+++ b/musetalk/whisper/whisper/__init__.py
@@ -106,7 +106,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
- checkpoint = torch.load(fp, map_location=device)
+ checkpoint = torch.load(fp, map_location=device, weights_only=False)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
diff --git a/requirements-rtx5070.txt b/requirements-rtx5070.txt
new file mode 100644
index 00000000..228479ce
--- /dev/null
+++ b/requirements-rtx5070.txt
@@ -0,0 +1,47 @@
+# Tailored requirements for Ubuntu 24.04 + CUDA 12.8 + RTX 5070
+# Updated based on expert feedback for stability
+#
+# ⚠️ INSTALLATION INSTRUCTIONS (Option B: Native Blackwell Support):
+# 1. Install PyTorch 2.7.0 (cu128) natively supporting sm_120:
+# pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 --force-reinstall
+# 2. Pre-install build tools and numpy:
+# pip install --upgrade pip setuptools wheel numpy
+# 3. Install chumpy manually without build isolation:
+# pip install chumpy --no-build-isolation
+# 4. Compile mmcv from source (must disable build isolation):
+# pip install -U openmim ninja Cython
+# NVCC_APPEND_FLAGS="-allow-unsupported-compiler" pip install "mmcv>=2.1.0" --no-build-isolation --no-cache-dir
+# 5. Install this file:
+# pip install -r requirements-rtx5070.txt
+# Core AI Frameworks - Versions Locked
+diffusers==0.30.2
+transformers==4.39.2
+accelerate==0.28.0
+huggingface_hub==0.36.2 # Golden version: satisfies both transformers <1.0 and gradio requirements
+einops>=0.8.1
+omegaconf
+
+# OpenMMLab Ecosystem (required by musetalk/utils/preprocessing.py)
+# mmcv needs source compilation for CUDA 12.8, install separately (see instructions at top)
+# NOTE: MMLab packages below are installed via `mim` in Step D. Commented out to avoid pip resolver conflicts with mmcv version.
+# mmengine>=0.10.0
+# mmdet>=3.1.0 # <--- Added: Required by mmpose
+chumpy # <--- Added: Required by mmpose
+# mmpose>=1.3.0
+
+# Media Processing - IMPORTANT: moviepy locked to 1.x
+moviepy==1.0.3
+opencv-python>=4.9.0.80,<4.11.0
+soundfile>=0.12.1
+librosa>=0.11.0
+ffmpeg-python
+imageio[ffmpeg]
+
+# Interface & Utilities
+gradio==5.24.0 # Locked to official tested version
+gradio-client==1.8.0 # Prevent conflict with hf-gradio/gradio 6.x
+pillow<10.0.0 # Prevent moviepy 1.0.3 crash (Image.ANTIALIAS removed in Pillow 10)
+requests
+gdown
+numpy<2.0.0
+setuptools # Required by mmcv build (pkg_resources)
diff --git a/reset_mmdet.py b/reset_mmdet.py
new file mode 100644
index 00000000..6200519e
--- /dev/null
+++ b/reset_mmdet.py
@@ -0,0 +1,28 @@
+import os
+
+def reset_mmdet_init():
+ # 明确指定 site-packages 路径
+ target_file = "/root/miniconda3/envs/musetalk/lib/python3.10/site-packages/mmdet/__init__.py"
+
+ if not os.path.exists(target_file):
+ print(f"Error: 找不到文件 {target_file}")
+ return
+
+ # 定义极其精简且无害的 __init__.py 内容
+ # 移除了所有版本检查逻辑
+ minimal_content = """__version__ = '3.3.0'
+
+def digit_version(v_str):
+ return [int(x) for x in v_str.split('.') if x.isdigit()]
+"""
+
+ try:
+ with open(target_file, 'w', encoding='utf-8') as f:
+ f.write(minimal_content)
+ print(f"✅ 成功重写: {target_file}")
+ print("🚀 已彻底移除 MMCV 版本检查。")
+ except Exception as e:
+ print(f"❌ 重写失败: {e}")
+
+if __name__ == "__main__":
+ reset_mmdet_init()
diff --git a/run_cli.py b/run_cli.py
new file mode 100644
index 00000000..ed07856d
--- /dev/null
+++ b/run_cli.py
@@ -0,0 +1,136 @@
+import argparse
+import os
+import sys
+import yaml
+import subprocess
+
+def validate_inputs(face_path, audio_path):
+ # Check if face_path exists and is non-empty
+ if not os.path.exists(face_path):
+ print(f"❌ 错误: 输入的 Face 文件/目录不存在: {face_path}")
+ return False
+ if not os.path.isdir(face_path) and os.path.getsize(face_path) == 0:
+ print(f"❌ 错误: 输入的 Face 文件为空 (0 字节): {face_path}")
+ return False
+
+ # Check if audio_path exists and is non-empty
+ if not os.path.exists(audio_path):
+ print(f"❌ 错误: 输入的 Audio 文件不存在: {audio_path}")
+ return False
+ if os.path.getsize(audio_path) == 0:
+ print(f"❌ 错误: 输入的 Audio 文件为空 (0 字节): {audio_path}")
+ return False
+
+ # Check face path validity (video or image)
+ _, ext = os.path.splitext(face_path)
+ ext_lower = ext.lower()
+ if ext_lower in ['.avi', '.mp4', '.mov', '.flv', '.mkv']:
+ # Try to open with OpenCV
+ try:
+ import cv2
+ cap = cv2.VideoCapture(face_path)
+ if not cap.isOpened():
+ print(f"❌ 错误: 无法打开视频文件 {face_path},文件可能损坏或写入不完整(如 moov atom 缺失)。")
+ cap.release()
+ return False
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ cap.release()
+ if fps <= 0:
+ print(f"❌ 错误: 视频文件 {face_path} 的帧率无效({fps}),文件可能损坏或写入不完整。")
+ return False
+ except ImportError:
+ # If cv2 is not installed (though it should be), we fall back to a basic check
+ pass
+ elif ext_lower in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
+ try:
+ import cv2
+ img = cv2.imread(face_path)
+ if img is None or img.size == 0:
+ print(f"❌ 错误: 无法读取图像文件 {face_path},文件可能损坏或不合法。")
+ return False
+ except ImportError:
+ pass
+ elif os.path.isdir(face_path):
+ # Check if directory contains images
+ import glob
+ img_list = glob.glob(os.path.join(face_path, '*.[jpJP][pnPN]*[gG]'))
+ if not img_list:
+ print(f"❌ 错误: 输入目录中未找到任何图片: {face_path}")
+ return False
+ else:
+ print(f"❌ 错误: 不支持的 Face 输入类型或后缀 {face_path}。")
+ return False
+
+ return True
+
+def main():
+ parser = argparse.ArgumentParser(description="MuseTalk CLI Wrapper")
+ parser.add_argument("--face", type=str, required=True, help="Input face video path")
+ parser.add_argument("--audio", type=str, required=True, help="Input audio path")
+ parser.add_argument("--outfile", type=str, required=True, help="Output video path (must be a specific .mp4 file)")
+ parser.add_argument("--version", type=str, default="v1.5", choices=["v1.0", "v1.5"])
+
+ args = parser.parse_args()
+
+ # 将输入的路径转化为绝对路径,确保 MuseTalk 底层能够正确识别
+ face_path = os.path.abspath(args.face)
+ audio_path = os.path.abspath(args.audio)
+ outfile_path = os.path.abspath(args.outfile)
+
+ # 验证输入文件的有效性
+ if not validate_inputs(face_path, audio_path):
+ sys.exit(1)
+
+ # MuseTalk 默认是通过 yaml 配置文件来执行批量任务的,为了支持单次 CLI 调用,这里动态生成一个临时的 yaml
+ config_dict = {
+ "task_0": {
+ "video_path": face_path,
+ "audio_path": audio_path
+ }
+ }
+
+ temp_yaml = os.path.abspath("configs/inference/temp_cli_task.yaml")
+ os.makedirs(os.path.dirname(temp_yaml), exist_ok=True)
+ with open(temp_yaml, "w", encoding="utf-8") as f:
+ yaml.dump(config_dict, f)
+
+ # 提取输出目录
+ out_dir = os.path.dirname(outfile_path)
+ os.makedirs(out_dir, exist_ok=True)
+
+ # 根据选定的版本,自动组装底层的模型路径
+ if args.version == "v1.0":
+ unet_model_path = "./models/musetalk/pytorch_model.bin"
+ unet_config = "./models/musetalk/musetalk.json"
+ version_arg = "v1"
+ else:
+ unet_model_path = "./models/musetalkV15/unet.pth"
+ unet_config = "./models/musetalkV15/musetalk.json"
+ version_arg = "v15"
+
+ # 组装最终的底层调用命令
+ cmd = [
+ sys.executable, "-m", "scripts.inference",
+ "--inference_config", temp_yaml,
+ "--result_dir", out_dir,
+ "--output_vid_name", outfile_path, # 传入绝对路径,底层代码 os.path.join 遇到绝对路径会自动使用绝对路径
+ "--unet_model_path", unet_model_path,
+ "--unet_config", unet_config,
+ "--version", version_arg
+ ]
+
+ print(f"🚀 开始执行数字人合成任务...\n底层调用命令: {' '.join(cmd)}\n")
+
+ try:
+ subprocess.run(cmd, check=True)
+ print(f"\n✅ 任务圆满完成!\n视频已保存至: {outfile_path}")
+ except subprocess.CalledProcessError as e:
+ print(f"\n❌ 推理过程中发生错误。底层推理命令执行失败,请检查上方日志。")
+ sys.exit(1)
+ finally:
+ # 清理临时的 yaml 配置文件
+ if os.path.exists(temp_yaml):
+ os.remove(temp_yaml)
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/inference.py b/scripts/inference.py
index 428afb99..72988f22 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -3,6 +3,16 @@
import math
import copy
import torch
+
+# --- PyTorch 2.6+ Compatibility Monkey Patch ---
+_original_load = torch.load
+def _patched_load(*args, **kwargs):
+ if 'weights_only' not in kwargs:
+ kwargs['weights_only'] = False
+ return _original_load(*args, **kwargs)
+torch.load = _patched_load
+# ---------------------------------------------
+
import glob
import shutil
import pickle
@@ -81,6 +91,7 @@ def main(args):
print("Loaded inference config:", inference_config)
# Process each task
+ has_error = False
for task_id in inference_config:
try:
# Get task configuration
@@ -89,6 +100,16 @@ def main(args):
if "result_name" in inference_config[task_id]:
args.output_vid_name = inference_config[task_id]["result_name"]
+ # Pre-validate files
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f"Input face video/image/directory path does not exist: {video_path}")
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"Input audio path does not exist: {audio_path}")
+ if not os.path.isdir(video_path) and os.path.getsize(video_path) == 0:
+ raise ValueError(f"Input face video/image file is empty (0 bytes): {video_path}")
+ if os.path.getsize(audio_path) == 0:
+ raise ValueError(f"Input audio file is empty (0 bytes): {audio_path}")
+
# Set bbox_shift based on version
if args.version == "v15":
bbox_shift = 0 # v15 uses fixed bbox_shift
@@ -121,9 +142,15 @@ def main(args):
save_dir_full = os.path.join(temp_dir, input_basename)
os.makedirs(save_dir_full, exist_ok=True)
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
- os.system(cmd)
+ ret_code = os.system(cmd)
+ if ret_code != 0:
+ raise RuntimeError(f"ffmpeg frame extraction failed with status code {ret_code}.")
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
+ if not input_img_list:
+ raise ValueError(f"No frames extracted from {video_path}. The video may be corrupted (e.g., moov atom missing).")
fps = get_video_fps(video_path)
+ if fps <= 0:
+ raise ValueError(f"Invalid fps detected for video {video_path}: {fps}")
elif get_file_type(video_path) == "image":
input_img_list = [video_path]
fps = args.fps
@@ -131,11 +158,15 @@ def main(args):
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
fps = args.fps
+ if not input_img_list:
+ raise ValueError(f"No images found in directory: {video_path}")
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
# Extract audio features
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
+ if whisper_input_features is None:
+ raise ValueError(f"Failed to process audio or extract features from {audio_path}")
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
@@ -247,6 +278,11 @@ def main(args):
print(f"Results saved to {output_vid_name}")
except Exception as e:
print("Error occurred during processing:", e)
+ has_error = True
+
+ if has_error:
+ print("\n❌ Error: One or more tasks failed during processing.")
+ sys.exit(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -260,7 +296,7 @@ def main(args):
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
- parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
+ parser.add_argument("--fps", type=float, default=25, help="Video frames per second")
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
diff --git a/scripts/realtime_inference.py b/scripts/realtime_inference.py
index 579b050f..d22c7fb1 100644
--- a/scripts/realtime_inference.py
+++ b/scripts/realtime_inference.py
@@ -4,6 +4,16 @@
import numpy as np
import cv2
import torch
+
+# --- PyTorch 2.6+ Compatibility Monkey Patch ---
+_original_load = torch.load
+def _patched_load(*args, **kwargs):
+ if 'weights_only' not in kwargs:
+ kwargs['weights_only'] = False
+ return _original_load(*args, **kwargs)
+torch.load = _patched_load
+# ---------------------------------------------
+
import glob
import pickle
import sys
@@ -16,7 +26,7 @@
from musetalk.utils.utils import datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
from musetalk.utils.blending import get_image_prepare_material, get_image_blending
-from musetalk.utils.utils import load_all_model
+from musetalk.utils.utils import load_all_model, get_video_fps, get_file_type
from musetalk.utils.audio_processor import AudioProcessor
import shutil
@@ -96,7 +106,7 @@ def init(self):
osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
self.prepare_material()
else:
- self.input_latent_list_cycle = torch.load(self.latents_out_path)
+ self.input_latent_list_cycle = torch.load(self.latents_out_path, weights_only=False)
with open(self.coords_path, 'rb') as f:
self.coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
@@ -133,7 +143,7 @@ def init(self):
else:
sys.exit()
else:
- self.input_latent_list_cycle = torch.load(self.latents_out_path)
+ self.input_latent_list_cycle = torch.load(self.latents_out_path, weights_only=False)
with open(self.coords_path, 'rb') as f:
self.coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
@@ -325,7 +335,7 @@ def inference(self, audio_path, out_vid_name, fps, skip_save_images):
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
- parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
+ parser.add_argument("--fps", type=float, default=25, help="Video frames per second")
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
parser.add_argument("--batch_size", type=int, default=20, help="Batch size for inference")
@@ -400,10 +410,15 @@ def inference(self, audio_path, out_vid_name, fps, skip_save_images):
batch_size=args.batch_size,
preparation=data_preparation)
+ if get_file_type(video_path) == "video":
+ fps = get_video_fps(video_path)
+ else:
+ fps = args.fps
+
audio_clips = inference_config[avatar_id]["audio_clips"]
for audio_num, audio_path in audio_clips.items():
- print("Inferring using:", audio_path)
+ print(f"Inferring using: {audio_path} at {fps} FPS")
avatar.inference(audio_path,
audio_num,
- args.fps,
+ fps,
args.skip_save_images)