@@ -171,7 +171,7 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea
171171 function totalAssets () public view virtual override returns (uint256 ) {
172172 IERC4626 _subVault = subVault;
173173 if (address (_subVault) == address (0 )) {
174- return super . totalAssets ( );
174+ return IERC20 ( asset ()). balanceOf ( address ( this ) );
175175 }
176176 return _subVault.convertToAssets (_subVault.balanceOf (address (this )));
177177 }
@@ -204,23 +204,21 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea
204204 */
205205 function _convertToShares (uint256 assets , MathUpgradeable.Rounding rounding ) internal view virtual override returns (uint256 shares ) {
206206 IERC4626 _subVault = subVault;
207- uint256 _totalAssets = totalAssets ();
208207 uint256 _totalPrincipal = totalPrincipal;
209208 uint256 _totalSupply = totalSupply ();
210209
211210 if (address (_subVault) == address (0 )) {
212- uint256 effectiveTotalAssets = enablePerformanceFee ? _min (_totalAssets , _totalPrincipal) : _totalAssets ;
211+ uint256 effectiveTotalAssets = enablePerformanceFee ? _min (totalAssets () , _totalPrincipal) : totalAssets () ;
213212 return _totalSupply.mulDiv (assets, effectiveTotalAssets, rounding);
214213 }
215214
216- uint256 subShares = _convertToSubShares (_subVault, assets, rounding);
215+ uint256 subShares = _assetsToSubVaultShares (_subVault, assets, rounding);
217216 uint256 totalSubShares = _subVault.balanceOf (address (this ));
218- uint256 profit = _totalAssets > _totalPrincipal ? _totalAssets - _totalPrincipal : 0 ;
219217
220218 if (enablePerformanceFee) {
221- // subSharesFee is the amount of subVault shares set aside as performance fee
222- uint256 subSharesFee = _convertToSubShares (_subVault, profit, rounding);
223- totalSubShares -= subSharesFee ;
219+ // since we use totalSubShares in the denominator of the final calculation,
220+ // and we are subtracting profit from it, we should use the same rounding direction for profit
221+ totalSubShares -= totalProfitInSubVaultShares ( _flipRounding (rounding)) ;
224222 }
225223
226224 return _totalSupply.mulDiv (subShares, totalSubShares, rounding);
@@ -231,46 +229,68 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea
231229 */
232230 function _convertToAssets (uint256 shares , MathUpgradeable.Rounding rounding ) internal view virtual override returns (uint256 assets ) {
233231 IERC4626 _subVault = subVault;
234- uint256 _totalAssets = totalAssets ();
235232 uint256 _totalPrincipal = totalPrincipal;
236233 uint256 _totalSupply = totalSupply ();
237234
235+ // if we have no subvault, we just do normal pro-rata calculation
238236 if (address (_subVault) == address (0 )) {
239- uint256 effectiveTotalAssets = enablePerformanceFee ? _min (_totalAssets , _totalPrincipal) : _totalAssets ;
237+ uint256 effectiveTotalAssets = enablePerformanceFee ? _min (totalAssets () , _totalPrincipal) : totalAssets () ;
240238 return effectiveTotalAssets.mulDiv (shares, _totalSupply, rounding);
241239 }
242-
240+
243241 uint256 totalSubShares = _subVault.balanceOf (address (this ));
244- uint256 profit = _totalAssets > _totalPrincipal ? _totalAssets - _totalPrincipal : 0 ;
245242
246- if (profit > 0 && enablePerformanceFee) {
247- // subSharesFee is the amount of subVault shares set aside as performance fee
248- uint256 subSharesFee = _convertToSubShares (_subVault, profit, rounding == MathUpgradeable.Rounding.Up ? MathUpgradeable.Rounding.Down : MathUpgradeable.Rounding.Up);
249- totalSubShares -= subSharesFee ;
243+ if (enablePerformanceFee) {
244+ // since we use totalSubShares in the numerator of the final calculation,
245+ // and we are subtracting profit from it, we should use the opposite rounding direction for profit
246+ totalSubShares -= totalProfitInSubVaultShares ( _flipRounding (rounding)) ;
250247 }
251248
252249 // totalSubShares * shares / totalMasterShares
253250 uint256 subShares = totalSubShares.mulDiv (shares, _totalSupply, rounding);
254251
255- return _convertToSubAssets (_subVault, subShares, rounding);
252+ return _subVaultSharesToAssets (_subVault, subShares, rounding);
256253 }
257254
258- function _convertToSubShares (IERC4626 _subVault , uint256 assets , MathUpgradeable.Rounding rounding ) internal view returns (uint256 subShares ) {
255+ function _assetsToSubVaultShares (IERC4626 _subVault , uint256 assets , MathUpgradeable.Rounding rounding ) internal view returns (uint256 subShares ) {
259256 return rounding == MathUpgradeable.Rounding.Up ? _subVault.previewWithdraw (assets) : _subVault.previewDeposit (assets);
260257 }
261258
262- function _convertToSubAssets (IERC4626 _subVault , uint256 subShares , MathUpgradeable.Rounding rounding ) internal view returns (uint256 assets ) {
259+ function _subVaultSharesToAssets (IERC4626 _subVault , uint256 subShares , MathUpgradeable.Rounding rounding ) internal view returns (uint256 assets ) {
263260 return rounding == MathUpgradeable.Rounding.Up ? _subVault.previewMint (subShares) : _subVault.previewRedeem (subShares);
264261 }
265262
266263 function _min (uint256 a , uint256 b ) internal pure returns (uint256 ) {
267264 return a <= b ? a : b;
268265 }
269266
267+ function _flipRounding (MathUpgradeable.Rounding rounding ) internal pure returns (MathUpgradeable.Rounding) {
268+ return rounding == MathUpgradeable.Rounding.Up ? MathUpgradeable.Rounding.Down : MathUpgradeable.Rounding.Up;
269+ }
270+
271+
272+ function totalProfit (MathUpgradeable.Rounding rounding ) public view returns (uint256 ) {
273+ IERC4626 _subVault = subVault;
274+ if (address (_subVault) == address (0 )) {
275+ uint256 _tokenBalance = IERC20 (asset ()).balanceOf (address (this ));
276+ return _tokenBalance > totalPrincipal ? _tokenBalance - totalPrincipal : 0 ;
277+ }
278+ uint256 totalSubShares = _subVault.balanceOf (address (this ));
279+ uint256 _totalAssets = _subVaultSharesToAssets (_subVault, totalSubShares, rounding);
280+ uint256 _totalPrincipal = totalPrincipal;
281+ return _totalAssets > _totalPrincipal ? _totalAssets - _totalPrincipal : 0 ;
282+ }
270283
271- function totalProfit () public view returns (uint256 ) {
272- uint256 _totalAssets = totalAssets ();
273- return _totalAssets > totalPrincipal ? _totalAssets - totalPrincipal : 0 ;
284+ function totalProfitInSubVaultShares (MathUpgradeable.Rounding rounding ) public view returns (uint256 ) {
285+ IERC4626 _subVault = subVault;
286+ if (address (_subVault) == address (0 )) {
287+ revert ("Subvault not set " );
288+ }
289+ uint256 profitAssets = totalProfit (rounding);
290+ if (profitAssets == 0 ) {
291+ return 0 ;
292+ }
293+ return _assetsToSubVaultShares (_subVault, profitAssets, rounding);
274294 }
275295
276296 /**
0 commit comments