diff --git a/mp_root_n.c b/mp_root_n.c index d904df88..e3c862b3 100644 --- a/mp_root_n.c +++ b/mp_root_n.c @@ -14,7 +14,7 @@ */ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) { - mp_int t1, t2, t3, a_; + mp_int t1, t2, t3, a_, d; int ilog2; mp_err err; @@ -27,7 +27,7 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) return MP_VAL; } - if ((err = mp_init_multi(&t1, &t2, &t3, NULL)) != MP_OKAY) { + if ((err = mp_init_multi(&t1, &t2, &t3, &d, NULL)) != MP_OKAY) { return err; } @@ -35,7 +35,7 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) a_ = *a; a_.sign = MP_ZPOS; - /* Compute seed: 2^(log_2(n)/b + 2)*/ + /* Compute seed: 2^(log_2(n)/b + 1)*/ ilog2 = mp_count_bits(a); /* @@ -57,7 +57,7 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) err = MP_OKAY; goto LBL_ERR; } - ilog2 = ilog2 / b; + ilog2 = (ilog2 - 1) / b; if (ilog2 == 0) { mp_set(c, 1uL); c->sign = a->sign; @@ -65,13 +65,13 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) goto LBL_ERR; } /* Start value must be larger than root */ - ilog2 += 2; - if ((err = mp_2expt(&t2,ilog2)) != MP_OKAY) goto LBL_ERR; + ilog2 += 1; + if ((err = mp_2expt(&t1, ilog2)) != MP_OKAY) goto LBL_ERR; + do { - /* t1 = t2 */ - if ((err = mp_copy(&t2, &t1)) != MP_OKAY) goto LBL_ERR; + mp_ord cmp; - /* t2 = t1 - ((t1**b - a) / (b * t1**(b-1))) */ + /* t2 = t1 - ceiling(((t1**b - a) / (b * t1**(b-1)))) */ /* t3 = t1**(b-1) */ if ((err = mp_expt_n(&t1, b - 1, &t3)) != MP_OKAY) goto LBL_ERR; @@ -80,6 +80,14 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) /* t2 = t1**b */ if ((err = mp_mul(&t3, &t1, &t2)) != MP_OKAY) goto LBL_ERR; + cmp = mp_cmp(&t2, &a_); + if (cmp == MP_EQ || cmp == MP_LT) { + err = MP_OKAY; + mp_exch(&t1, c); + c->sign = a->sign; + goto LBL_ERR; + } + /* t2 = t1**b - a */ if ((err = mp_sub(&t2, &a_, &t2)) != MP_OKAY) goto LBL_ERR; @@ -88,35 +96,19 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) if ((err = mp_mul_d(&t3, (mp_digit)b, &t3)) != MP_OKAY) goto LBL_ERR; /* t3 = (t1**b - a)/(b * t1**(b-1)) */ - if ((err = mp_div(&t2, &t3, &t3, NULL)) != MP_OKAY) goto LBL_ERR; + if ((err = mp_div(&t2, &t3, &t3, &d)) != MP_OKAY) goto LBL_ERR; + /* round up t3 - so t1 will be rounded down */ + if(!mp_iszero(&d)) { + if ((err = mp_add_d(&t3, 1uL, &t3)) != MP_OKAY) goto LBL_ERR; + } - if ((err = mp_sub(&t1, &t3, &t2)) != MP_OKAY) goto LBL_ERR; + /* t1 = t1 - t3 */ + if ((err = mp_sub(&t1, &t3, &t1)) != MP_OKAY) goto LBL_ERR; - /* - Number of rounds is at most log_2(root). If it is more it - got stuck, so break out of the loop and do the rest manually. - */ - if (ilog2-- == 0) { - break; - } - } while (mp_cmp(&t1, &t2) != MP_EQ); + /* while t3 != 1 */ + } while (!((t3.used == 1u) && (t3.dp[0] == 1u))); /* result can be off by a few so check */ - /* Loop beneath can overshoot by one if found root is smaller than actual root */ - for (;;) { - mp_ord cmp; - if ((err = mp_expt_n(&t1, b, &t2)) != MP_OKAY) goto LBL_ERR; - cmp = mp_cmp(&t2, &a_); - if (cmp == MP_EQ) { - err = MP_OKAY; - goto LBL_ERR; - } - if (cmp == MP_LT) { - if ((err = mp_add_d(&t1, 1uL, &t1)) != MP_OKAY) goto LBL_ERR; - } else { - break; - } - } /* correct overshoot from above or from recurrence */ for (;;) { if ((err = mp_expt_n(&t1, b, &t2)) != MP_OKAY) goto LBL_ERR; @@ -134,7 +126,7 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) c->sign = a->sign; LBL_ERR: - mp_clear_multi(&t1, &t2, &t3, NULL); + mp_clear_multi(&t1, &t2, &t3, &d, NULL); return err; }