1
+ import numpy as np
2
+ import scipy as sp
3
+ from numba import njit
4
+ from functools import lru_cache
5
+
6
+
7
+ @lru_cache (maxsize = 2 )
8
+ def create_convolution_kernels (ndim ):
9
+ """Create convolution kernels for advective flux computation.
10
+
11
+ Parameters
12
+ ----------
13
+ ndim : int
14
+ Number of dimensions (2 or 3)
15
+
16
+ Returns
17
+ -------
18
+ dict
19
+ Dictionary containing kernels for each direction
20
+
21
+ Notes
22
+ -----
23
+ Results are cached since kernels don't change during computation.
24
+ """
25
+ if ndim == 2 :
26
+ kernel_u = np .array ([[0.5 , 1.0 , 0.5 ], [0.5 , 1.0 , 0.5 ]])
27
+ return {
28
+ 'u' : kernel_u ,
29
+ 'v' : kernel_u .T
30
+ }
31
+ elif ndim == 3 :
32
+ kernel_u = np .array ([
33
+ [[1 , 2 , 1 ], [2 , 4 , 2 ], [1 , 2 , 1 ]],
34
+ [[1 , 2 , 1 ], [2 , 4 , 2 ], [1 , 2 , 1 ]]
35
+ ])
36
+ return {
37
+ 'u' : kernel_u ,
38
+ 'v' : np .swapaxes (kernel_u , 1 , 0 ),
39
+ 'w' : np .swapaxes (kernel_u , 2 , 0 )
40
+ }
41
+ else :
42
+ raise ValueError (f"Unsupported dimension: { ndim } " )
43
+
44
+
45
+ @njit
46
+ def _numba_convolve_2d (data , kernel ):
47
+ """Numba-compiled 2D convolution for better performance."""
48
+ data_h , data_w = data .shape
49
+ kernel_h , kernel_w = kernel .shape
50
+
51
+ result_h = data_h - kernel_h + 1
52
+ result_w = data_w - kernel_w + 1
53
+ result = np .zeros ((result_h , result_w ))
54
+
55
+ for i in range (result_h ):
56
+ for j in range (result_w ):
57
+ for ki in range (kernel_h ):
58
+ for kj in range (kernel_w ):
59
+ result [i , j ] += data [i + ki , j + kj ] * kernel [ki , kj ]
60
+
61
+ return result
62
+
63
+
64
+ @njit
65
+ def _numba_convolve_3d (data , kernel ):
66
+ """Numba-compiled 3D convolution for better performance."""
67
+ data_d , data_h , data_w = data .shape
68
+ kernel_d , kernel_h , kernel_w = kernel .shape
69
+
70
+ result_d = data_d - kernel_d + 1
71
+ result_h = data_h - kernel_h + 1
72
+ result_w = data_w - kernel_w + 1
73
+ result = np .zeros ((result_d , result_h , result_w ))
74
+
75
+ for i in range (result_d ):
76
+ for j in range (result_h ):
77
+ for k in range (result_w ):
78
+ for ki in range (kernel_d ):
79
+ for kj in range (kernel_h ):
80
+ for kk in range (kernel_w ):
81
+ result [i , j , k ] += data [i + ki , j + kj , k + kk ] * kernel [ki , kj , kk ]
82
+
83
+ return result
84
+
85
+
86
+ def apply_convolution_kernel (data , kernel , normalize = True , axis_swap = None , use_numba = True ):
87
+ """Apply convolution kernel with optional normalization and axis swapping.
88
+
89
+ Parameters
90
+ ----------
91
+ data : np.ndarray
92
+ Input data array
93
+ kernel : np.ndarray
94
+ Convolution kernel
95
+ normalize : bool, default=True
96
+ Whether to normalize by kernel sum
97
+ axis_swap : tuple or None, default=None
98
+ Tuple of (from_axis, to_axis) for np.moveaxis
99
+ use_numba : bool, default=True
100
+ Whether to use Numba-compiled convolution (faster for repeated calls)
101
+
102
+ Returns
103
+ -------
104
+ np.ndarray
105
+ Convolved result
106
+
107
+ Notes
108
+ -----
109
+ For large arrays or single calls, scipy.signal.fftconvolve might be faster.
110
+ For repeated calls on smaller arrays, Numba convolution is typically faster.
111
+ """
112
+ if use_numba and data .ndim in (2 , 3 ):
113
+ if data .ndim == 2 :
114
+ result = _numba_convolve_2d (data , kernel )
115
+ else : # 3D
116
+ result = _numba_convolve_3d (data , kernel )
117
+ else :
118
+ # Fallback to scipy for other dimensions or when requested
119
+ result = sp .signal .fftconvolve (data , kernel , mode = "valid" )
120
+
121
+ if normalize :
122
+ result = result / kernel .sum ()
123
+
124
+ if axis_swap is not None :
125
+ result = np .moveaxis (result , axis_swap [0 ], axis_swap [1 ])
126
+
127
+ return result
128
+
129
+
130
+ # Convenience functions for specific operations
131
+ def apply_u_kernel_convolution (data , kernel_u , normalize = True ):
132
+ """Apply u-direction kernel with standard axis swap."""
133
+ return apply_convolution_kernel (data , kernel_u , normalize = normalize , axis_swap = (0 , - 1 ))
134
+
135
+
136
+ def apply_v_kernel_convolution_3d (data , kernel_v , normalize = True ):
137
+ """Apply v-direction kernel for 3D with standard axis swap."""
138
+ return apply_convolution_kernel (data , kernel_v , normalize = normalize , axis_swap = (- 1 , 0 ))
0 commit comments