Skip to content

Commit ea515ec

Browse files
refactor: improve AMM error handling and mathematical calculations
- Enhanced error handling in swap and deposit liquidity instructions - Improved mathematical utilities for precision and safety - Updated test suites to cover edge cases with decimals and slippage - Refined type definitions in fuzz tests
1 parent cb5436c commit ea515ec

9 files changed

Lines changed: 105 additions & 63 deletions

File tree

programs/amm/src/errors.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,4 @@ pub enum AmmError {
2626
DecimalsTooLarge,
2727
#[msg("Fee exceeds maximum allowed")]
2828
FeeExceedsMaximum,
29-
#[msg("Slippage tolerance exceeded")]
30-
SlippageExceeded,
3129
}

programs/amm/src/instructions/deposit_liquidity.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,10 @@ impl<'info> DepositLiquidity<'info> {
137137
self.liquidity_pool.amount_mint_b,
138138
)?;
139139

140-
require!(liquidity >= min_lp_out, AmmError::SlippageExceeded);
140+
require!(
141+
liquidity >= min_lp_out,
142+
AmmError::InsufficientLiquidityMinted
143+
);
141144

142145
let signer_seeds: &[&[&[u8]]] = &[&[
143146
LIQUIDITY_POOL_SEED,

programs/amm/src/instructions/swap.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,13 @@ pub struct Swap<'info> {
6565
}
6666

6767
impl<'info> Swap<'info> {
68-
pub fn swap(&mut self, pool_id: u64, params: SwapParams, slippage_limit: u64) -> Result<()> {
68+
pub fn swap(&mut self, pool_id: u64, params: SwapParams) -> Result<()> {
6969
// calculate and transfer input mint from signer to vault
7070
let input_is_mint_a = self.liquidity_pool.mint_a == self.input_mint.key();
7171

7272
let (input_amount, output_amount, protocol_fee_amount) =
7373
self.calculate_amounts(&params, input_is_mint_a)?;
7474

75-
if params.is_exact_in() {
76-
require!(output_amount >= slippage_limit, AmmError::SlippageExceeded);
77-
} else if slippage_limit > 0 {
78-
require!(input_amount <= slippage_limit, AmmError::SlippageExceeded);
79-
}
80-
8175
msg!(
8276
"Swapping {} of input mint for {} of output mint",
8377
input_amount,
@@ -175,8 +169,14 @@ impl<'info> Swap<'info> {
175169

176170
#[derive(AnchorSerialize, AnchorDeserialize)]
177171
pub enum SwapParams {
178-
ExactIn { input_amount: u64 },
179-
ExactOut { output_amount: u64 },
172+
ExactIn {
173+
input_amount: u64,
174+
min_output_amount: u64,
175+
},
176+
ExactOut {
177+
output_amount: u64,
178+
max_input_amount: u64,
179+
},
180180
}
181181

182182
impl SwapParams {

programs/amm/src/lib.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,16 @@ pub mod amm {
3434
amount_b: u64,
3535
min_lp_out: u64,
3636
) -> Result<()> {
37-
ctx.accounts.deposit_liquidity(pool_id, amount_a, amount_b, min_lp_out)
37+
ctx.accounts
38+
.deposit_liquidity(pool_id, amount_a, amount_b, min_lp_out)
3839
}
3940

4041
pub fn redeem_lp(ctx: Context<RedeemLp>, pool_id: u64, lp_amount: u64) -> Result<()> {
4142
ctx.accounts.redeem_lp(pool_id, lp_amount)
4243
}
4344

44-
pub fn swap(ctx: Context<Swap>, pool_id: u64, params: SwapParams, slippage_limit: u64) -> Result<()> {
45-
ctx.accounts.swap(pool_id, params, slippage_limit)
45+
pub fn swap(ctx: Context<Swap>, pool_id: u64, params: SwapParams) -> Result<()> {
46+
ctx.accounts.swap(pool_id, params)
4647
}
4748

4849
pub fn withdraw_protocol_fees(ctx: Context<WithdrawProtocolFees>, pool_id: u64) -> Result<()> {

programs/amm/src/utils/math.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,10 @@ pub fn calculate_swap_amounts(
314314
.ok_or(AmmError::MathUnderflow)?;
315315

316316
match swap_params {
317-
SwapParams::ExactIn { input_amount } => {
317+
SwapParams::ExactIn {
318+
input_amount,
319+
min_output_amount,
320+
} => {
318321
require!(*input_amount > 0, AmmError::InsufficientInputAmount);
319322
require!(
320323
*input_amount < reserve_input,
@@ -345,6 +348,12 @@ pub fn calculate_swap_amounts(
345348
.try_into()
346349
.map_err(|_| AmmError::MathOverflow)?;
347350

351+
// slippage check
352+
require!(
353+
output_amount >= *min_output_amount,
354+
AmmError::InsufficientOutputAmount
355+
);
356+
348357
// Calcular el fee del protocolo del input_amount
349358
let protocol_fee_amount = (*input_amount as u128)
350359
.checked_mul(protocol_fee_bps as u128)
@@ -356,7 +365,10 @@ pub fn calculate_swap_amounts(
356365

357366
Ok((*input_amount, output_amount, protocol_fee_amount))
358367
}
359-
SwapParams::ExactOut { output_amount } => {
368+
SwapParams::ExactOut {
369+
output_amount,
370+
max_input_amount,
371+
} => {
360372
require!(*output_amount > 0, AmmError::InsufficientOutputAmount);
361373
require!(
362374
*output_amount < reserve_output,
@@ -383,6 +395,14 @@ pub fn calculate_swap_amounts(
383395
.try_into()
384396
.map_err(|_| AmmError::MathOverflow)?;
385397

398+
// slippage check
399+
if *max_input_amount > 0 {
400+
require!(
401+
input_amount <= *max_input_amount,
402+
AmmError::ExcessiveInputAmount
403+
);
404+
}
405+
386406
// Calcular el fee del protocolo del input_amount
387407
let protocol_fee_amount = (input_amount as u128)
388408
.checked_mul(protocol_fee_bps as u128)

tests/amm-decimals-local.test.ts

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,14 @@ describe("amm - decimals tests", () => {
445445
const [poolBefore] = await getLiquidityPoolAccount(poolId);
446446

447447
const inputAmount = 10_000_000; // 10 USDC
448+
const minOutputAmount = 0;
449+
450+
const param: anchor.IdlTypes<Amm>["swapParams"] = {
451+
exactIn: { inputAmount: bn(inputAmount), minOutputAmount: bn(minOutputAmount) },
452+
};
448453

449454
const tx = await program.methods
450-
.swap(bn(poolId), { exactIn: { inputAmount: bn(inputAmount) } }, bn(0))
455+
.swap(bn(poolId), param)
451456
.accounts({
452457
tokenProgram: TOKEN_PROGRAM_ID,
453458
signer: liquidityProvider.publicKey,
@@ -469,9 +474,14 @@ describe("amm - decimals tests", () => {
469474
const [poolBefore] = await getLiquidityPoolAccount(poolId);
470475

471476
const inputAmount = 100_000_000; // 0.1 SOL
477+
const minOutputAmount = 0;
478+
479+
const param: anchor.IdlTypes<Amm>["swapParams"] = {
480+
exactIn: { inputAmount: bn(inputAmount), minOutputAmount: bn(minOutputAmount) },
481+
};
472482

473483
const tx = await program.methods
474-
.swap(bn(poolId), { exactIn: { inputAmount: bn(inputAmount) } }, bn(0))
484+
.swap(bn(poolId), param)
475485
.accounts({
476486
tokenProgram: TOKEN_PROGRAM_ID,
477487
signer: liquidityProvider.publicKey,

tests/amm-local.test.ts

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -349,18 +349,18 @@ describe("amm", () => {
349349

350350
it("`swap` method using 'exact in' param with mint A!", async () => {
351351
const INPUT_AMOUNT = 10_000_000; // 10 token
352-
const SLIPPAGE_LIMIT = 0; // expressed in amount, no bps
352+
const MIN_OUTPUT_AMOUNT = 0;
353353

354354
const param: anchor.IdlTypes<Amm>["swapParams"] = {
355-
exactIn: { inputAmount: bn(INPUT_AMOUNT) },
355+
exactIn: { inputAmount: bn(INPUT_AMOUNT), minOutputAmount: bn(MIN_OUTPUT_AMOUNT) },
356356
};
357357

358358
const [prevLiquidityPool] = await getLiquidityPoolAccount(poolId);
359359
console.log("A amount in vault:", prevLiquidityPool.amountMintA);
360360
console.log("B amount in vault:", prevLiquidityPool.amountMintB);
361361

362362
const tx = await program.methods
363-
.swap(bn(poolId), param, bn(SLIPPAGE_LIMIT))
363+
.swap(bn(poolId), param)
364364
.accounts({
365365
signer: user.publicKey,
366366
inputMint: mintA,
@@ -397,18 +397,18 @@ describe("amm", () => {
397397

398398
it("`swap` method using 'exact out' param with mint A!", async () => {
399399
const OUTPUT_AMOUNT = 20_000_000; // 2 token
400-
const SLIPPAGE_LIMIT = 0; // expressed in amount, no bps
400+
const MAX_INPUT_AMOUNT = 0;
401401

402402
const param: anchor.IdlTypes<Amm>["swapParams"] = {
403-
exactOut: { outputAmount: bn(OUTPUT_AMOUNT) },
403+
exactOut: { outputAmount: bn(OUTPUT_AMOUNT), maxInputAmount: bn(MAX_INPUT_AMOUNT) },
404404
};
405405

406406
const [prevLiquidityPool] = await getLiquidityPoolAccount(poolId);
407407
console.log("A amount in vault:", prevLiquidityPool.amountMintA);
408408
console.log("B amount in vault:", prevLiquidityPool.amountMintB);
409409

410410
const tx = await program.methods
411-
.swap(bn(poolId), param, bn(SLIPPAGE_LIMIT))
411+
.swap(bn(poolId), param)
412412
.accounts({
413413
signer: user.publicKey,
414414
inputMint: mintA,
@@ -445,18 +445,18 @@ describe("amm", () => {
445445

446446
it("`swap` method using 'exact in' param with mint B!", async () => {
447447
const INPUT_AMOUNT = 10_000_000; // 10 token
448-
const SLIPPAGE_LIMIT = 0; // expressed in amount, no bps
448+
const MIN_OUTPUT_AMOUNT = 0;
449449

450450
const param: anchor.IdlTypes<Amm>["swapParams"] = {
451-
exactIn: { inputAmount: bn(INPUT_AMOUNT) },
451+
exactIn: { inputAmount: bn(INPUT_AMOUNT), minOutputAmount: bn(MIN_OUTPUT_AMOUNT) },
452452
};
453453

454454
const [prevLiquidityPool] = await getLiquidityPoolAccount(poolId);
455455
console.log("A amount in vault:", prevLiquidityPool.amountMintA);
456456
console.log("B amount in vault:", prevLiquidityPool.amountMintB);
457457

458458
const tx = await program.methods
459-
.swap(bn(poolId), param, bn(SLIPPAGE_LIMIT))
459+
.swap(bn(poolId), param)
460460
.accounts({
461461
signer: user.publicKey,
462462
inputMint: mintB,
@@ -493,18 +493,18 @@ describe("amm", () => {
493493

494494
it("`swap` method using 'exact out' param with mint B!", async () => {
495495
const OUTPUT_AMOUNT = 2_000_000; // 2 token
496-
const SLIPPAGE_LIMIT = 0; // expressed in amount, no bps
496+
const MAX_INPUT_AMOUNT = 0;
497497

498498
const param: anchor.IdlTypes<Amm>["swapParams"] = {
499-
exactOut: { outputAmount: bn(OUTPUT_AMOUNT) },
499+
exactOut: { outputAmount: bn(OUTPUT_AMOUNT), maxInputAmount: bn(MAX_INPUT_AMOUNT) },
500500
};
501501

502502
const [prevLiquidityPool] = await getLiquidityPoolAccount(poolId);
503503
console.log("A amount in vault:", prevLiquidityPool.amountMintA);
504504
console.log("B amount in vault:", prevLiquidityPool.amountMintB);
505505

506506
const tx = await program.methods
507-
.swap(bn(poolId), param, bn(SLIPPAGE_LIMIT))
507+
.swap(bn(poolId), param)
508508
.accounts({
509509
signer: user.publicKey,
510510
inputMint: mintB,

0 commit comments

Comments
 (0)