From 22faf6ad3bcc0ae478a9a3e2d8e35888d88d6ce8 Mon Sep 17 00:00:00 2001 From: Stefan Krah Date: Tue, 9 Jun 2020 01:55:47 +0200 Subject: [3.7] Revert bpo-39576: Prevent memory error for overly optimistic precisions (GH-20748) This reverts commit c6f95543b4832c3f0170179da39bcf99b40a7aa8. --- Lib/test/test_decimal.py | 35 --------- Modules/_decimal/libmpdec/mpdecimal.c | 77 +------------------ Modules/_decimal/tests/deccheck.py | 139 +--------------------------------- 3 files changed, 6 insertions(+), 245 deletions(-) diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 0e9cd3095c8..1f37b5372a3 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -5476,41 +5476,6 @@ def __abs__(self): self.assertEqual(Decimal.from_float(cls(101.1)), Decimal.from_float(101.1)) - def test_maxcontext_exact_arith(self): - - # Make sure that exact operations do not raise MemoryError due - # to huge intermediate values when the context precision is very - # large. - - # The following functions fill the available precision and are - # therefore not suitable for large precisions (by design of the - # specification). - MaxContextSkip = ['logical_invert', 'next_minus', 'next_plus', - 'logical_and', 'logical_or', 'logical_xor', - 'next_toward', 'rotate', 'shift'] - - Decimal = C.Decimal - Context = C.Context - localcontext = C.localcontext - - # Here only some functions that are likely candidates for triggering a - # MemoryError are tested. deccheck.py has an exhaustive test. - maxcontext = Context(prec=C.MAX_PREC, Emin=C.MIN_EMIN, Emax=C.MAX_EMAX) - with localcontext(maxcontext): - self.assertEqual(Decimal(0).exp(), 1) - self.assertEqual(Decimal(1).ln(), 0) - self.assertEqual(Decimal(1).log10(), 0) - self.assertEqual(Decimal(10**2).log10(), 2) - self.assertEqual(Decimal(10**223).log10(), 223) - self.assertEqual(Decimal(10**19).logb(), 19) - self.assertEqual(Decimal(4).sqrt(), 2) - self.assertEqual(Decimal("40E9").sqrt(), Decimal('2.0E+5')) - self.assertEqual(divmod(Decimal(10), 3), (3, 1)) - self.assertEqual(Decimal(10) // 3, 3) - self.assertEqual(Decimal(4) / 2, 2) - self.assertEqual(Decimal(400) ** -1, Decimal('0.0025')) - - @requires_docstrings @unittest.skipUnless(C, "test requires C version") class SignatureTest(unittest.TestCase): diff --git a/Modules/_decimal/libmpdec/mpdecimal.c b/Modules/_decimal/libmpdec/mpdecimal.c index 0986edb576a..bfa8bb343e6 100644 --- a/Modules/_decimal/libmpdec/mpdecimal.c +++ b/Modules/_decimal/libmpdec/mpdecimal.c @@ -3781,43 +3781,6 @@ mpd_qdiv(mpd_t *q, const mpd_t *a, const mpd_t *b, const mpd_context_t *ctx, uint32_t *status) { _mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status); - - if (*status & MPD_Malloc_error) { - /* Inexact quotients (the usual case) fill the entire context precision, - * which can lead to malloc() failures for very high precisions. Retry - * the operation with a lower precision in case the result is exact. - * - * We need an upper bound for the number of digits of a_coeff / b_coeff - * when the result is exact. If a_coeff' * 1 / b_coeff' is in lowest - * terms, then maxdigits(a_coeff') + maxdigits(1 / b_coeff') is a suitable - * bound. - * - * 1 / b_coeff' is exact iff b_coeff' exclusively has prime factors 2 or 5. - * The largest amount of digits is generated if b_coeff' is a power of 2 or - * a power of 5 and is less than or equal to log5(b_coeff') <= log2(b_coeff'). - * - * We arrive at a total upper bound: - * - * maxdigits(a_coeff') + maxdigits(1 / b_coeff') <= - * a->digits + log2(b_coeff) = - * a->digits + log10(b_coeff) / log10(2) <= - * a->digits + b->digits * 4; - */ - uint32_t workstatus = 0; - mpd_context_t workctx = *ctx; - workctx.prec = a->digits + b->digits * 4; - if (workctx.prec >= ctx->prec) { - return; /* No point in retrying, keep the original error. */ - } - - _mpd_qdiv(SET_IDEAL_EXP, q, a, b, &workctx, &workstatus); - if (workstatus == 0) { /* The result is exact, unrounded, normal etc. */ - *status = 0; - return; - } - - mpd_seterror(q, *status, status); - } } /* Internal function. */ @@ -7739,9 +7702,9 @@ mpd_qinvroot(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, /* END LIBMPDEC_ONLY */ /* Algorithm from decimal.py */ -static void -_mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, - uint32_t *status) +void +mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, + uint32_t *status) { mpd_context_t maxcontext; MPD_NEW_STATIC(c,0,0,0,0); @@ -7873,40 +7836,6 @@ _mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, goto out; } -void -mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, - uint32_t *status) -{ - _mpd_qsqrt(result, a, ctx, status); - - if (*status & (MPD_Malloc_error|MPD_Division_impossible)) { - /* The above conditions can occur at very high context precisions - * if intermediate values get too large. Retry the operation with - * a lower context precision in case the result is exact. - * - * If the result is exact, an upper bound for the number of digits - * is the number of digits in the input. - * - * NOTE: sqrt(40e9) = 2.0e+5 /\ digits(40e9) = digits(2.0e+5) = 2 - */ - uint32_t workstatus = 0; - mpd_context_t workctx = *ctx; - workctx.prec = a->digits; - - if (workctx.prec >= ctx->prec) { - return; /* No point in repeating this, keep the original error. */ - } - - _mpd_qsqrt(result, a, &workctx, &workstatus); - if (workstatus == 0) { - *status = 0; - return; - } - - mpd_seterror(result, *status, status); - } -} - /******************************************************************************/ /* Base conversions */ diff --git a/Modules/_decimal/tests/deccheck.py b/Modules/_decimal/tests/deccheck.py index 5cd5db57114..f907531e1ff 100644 --- a/Modules/_decimal/tests/deccheck.py +++ b/Modules/_decimal/tests/deccheck.py @@ -125,12 +125,6 @@ 'special': ('context.__reduce_ex__', 'context.create_decimal_from_float') } -# Functions that set no context flags but whose result can differ depending -# on prec, Emin and Emax. -MaxContextSkip = ['is_normal', 'is_subnormal', 'logical_invert', 'next_minus', - 'next_plus', 'number_class', 'logical_and', 'logical_or', - 'logical_xor', 'next_toward', 'rotate', 'shift'] - # Functions that require a restricted exponent range for reasonable runtimes. UnaryRestricted = [ '__ceil__', '__floor__', '__int__', '__trunc__', @@ -350,20 +344,6 @@ def __init__(self, funcname, operands): self.pex = RestrictedList() # Python exceptions for P.Decimal self.presults = RestrictedList() # P.Decimal results - # If the above results are exact, unrounded and not clamped, repeat - # the operation with a maxcontext to ensure that huge intermediate - # values do not cause a MemoryError. - self.with_maxcontext = False - self.maxcontext = context.c.copy() - self.maxcontext.prec = C.MAX_PREC - self.maxcontext.Emax = C.MAX_EMAX - self.maxcontext.Emin = C.MIN_EMIN - self.maxcontext.clear_flags() - - self.maxop = RestrictedList() # converted C.Decimal operands - self.maxex = RestrictedList() # Python exceptions for C.Decimal - self.maxresults = RestrictedList() # C.Decimal results - # ====================================================================== # SkipHandler: skip known discrepancies @@ -565,17 +545,13 @@ def function_as_string(t): if t.contextfunc: cargs = t.cop pargs = t.pop - maxargs = t.maxop cfunc = "c_func: %s(" % t.funcname pfunc = "p_func: %s(" % t.funcname - maxfunc = "max_func: %s(" % t.funcname else: cself, cargs = t.cop[0], t.cop[1:] pself, pargs = t.pop[0], t.pop[1:] - maxself, maxargs = t.maxop[0], t.maxop[1:] cfunc = "c_func: %s.%s(" % (repr(cself), t.funcname) pfunc = "p_func: %s.%s(" % (repr(pself), t.funcname) - maxfunc = "max_func: %s.%s(" % (repr(maxself), t.funcname) err = cfunc for arg in cargs: @@ -589,14 +565,6 @@ def function_as_string(t): err = err.rstrip(", ") err += ")" - if t.with_maxcontext: - err += "\n" - err += maxfunc - for arg in maxargs: - err += "%s, " % repr(arg) - err = err.rstrip(", ") - err += ")" - return err def raise_error(t): @@ -609,24 +577,9 @@ def raise_error(t): err = "Error in %s:\n\n" % t.funcname err += "input operands: %s\n\n" % (t.op,) err += function_as_string(t) - - err += "\n\nc_result: %s\np_result: %s\n" % (t.cresults, t.presults) - if t.with_maxcontext: - err += "max_result: %s\n\n" % (t.maxresults) - else: - err += "\n" - - err += "c_exceptions: %s\np_exceptions: %s\n" % (t.cex, t.pex) - if t.with_maxcontext: - err += "max_exceptions: %s\n\n" % t.maxex - else: - err += "\n" - - err += "%s\n" % str(t.context) - if t.with_maxcontext: - err += "%s\n" % str(t.maxcontext) - else: - err += "\n" + err += "\n\nc_result: %s\np_result: %s\n\n" % (t.cresults, t.presults) + err += "c_exceptions: %s\np_exceptions: %s\n\n" % (t.cex, t.pex) + err += "%s\n\n" % str(t.context) raise VerifyError(err) @@ -650,13 +603,6 @@ def raise_error(t): # are printed to stdout. # ====================================================================== -def all_nan(a): - if isinstance(a, C.Decimal): - return a.is_nan() - elif isinstance(a, tuple): - return all(all_nan(v) for v in a) - return False - def convert(t, convstr=True): """ t is the testset. At this stage the testset contains a tuple of operands t.op of various types. For decimal methods the first @@ -671,12 +617,10 @@ def convert(t, convstr=True): for i, op in enumerate(t.op): context.clear_status() - t.maxcontext.clear_flags() if op in RoundModes: t.cop.append(op) t.pop.append(op) - t.maxop.append(op) elif not t.contextfunc and i == 0 or \ convstr and isinstance(op, str): @@ -694,25 +638,11 @@ def convert(t, convstr=True): p = None pex = e.__class__ - try: - C.setcontext(t.maxcontext) - maxop = C.Decimal(op) - maxex = None - except (TypeError, ValueError, OverflowError) as e: - maxop = None - maxex = e.__class__ - finally: - C.setcontext(context.c) - t.cop.append(c) t.cex.append(cex) - t.pop.append(p) t.pex.append(pex) - t.maxop.append(maxop) - t.maxex.append(maxex) - if cex is pex: if str(c) != str(p) or not context.assert_eq_status(): raise_error(t) @@ -722,21 +652,14 @@ def convert(t, convstr=True): else: raise_error(t) - # The exceptions in the maxcontext operation can legitimately - # differ, only test that maxex implies cex: - if maxex is not None and cex is not maxex: - raise_error(t) - elif isinstance(op, Context): t.context = op t.cop.append(op.c) t.pop.append(op.p) - t.maxop.append(t.maxcontext) else: t.cop.append(op) t.pop.append(op) - t.maxop.append(op) return 1 @@ -750,7 +673,6 @@ def callfuncs(t): t.rc and t.rp are the results of the operation. """ context.clear_status() - t.maxcontext.clear_flags() try: if t.contextfunc: @@ -778,35 +700,6 @@ def callfuncs(t): t.rp = None t.pex.append(e.__class__) - # If the above results are exact, unrounded, normal etc., repeat the - # operation with a maxcontext to ensure that huge intermediate values - # do not cause a MemoryError. - if (t.funcname not in MaxContextSkip and - not context.c.flags[C.InvalidOperation] and - not context.c.flags[C.Inexact] and - not context.c.flags[C.Rounded] and - not context.c.flags[C.Subnormal] and - not context.c.flags[C.Clamped] and - not context.clamp and # results are padded to context.prec if context.clamp==1. - not any(isinstance(v, C.Context) for v in t.cop)): # another context is used. - t.with_maxcontext = True - try: - if t.contextfunc: - maxargs = t.maxop - t.rmax = getattr(t.maxcontext, t.funcname)(*maxargs) - else: - maxself = t.maxop[0] - maxargs = t.maxop[1:] - try: - C.setcontext(t.maxcontext) - t.rmax = getattr(maxself, t.funcname)(*maxargs) - finally: - C.setcontext(context.c) - t.maxex.append(None) - except (TypeError, ValueError, OverflowError, MemoryError) as e: - t.rmax = None - t.maxex.append(e.__class__) - def verify(t, stat): """ t is the testset. At this stage the testset contains the following tuples: @@ -821,9 +714,6 @@ def verify(t, stat): """ t.cresults.append(str(t.rc)) t.presults.append(str(t.rp)) - if t.with_maxcontext: - t.maxresults.append(str(t.rmax)) - if isinstance(t.rc, C.Decimal) and isinstance(t.rp, P.Decimal): # General case: both results are Decimals. t.cresults.append(t.rc.to_eng_string()) @@ -835,12 +725,6 @@ def verify(t, stat): t.presults.append(str(t.rp.imag)) t.presults.append(str(t.rp.real)) - if t.with_maxcontext and isinstance(t.rmax, C.Decimal): - t.maxresults.append(t.rmax.to_eng_string()) - t.maxresults.append(t.rmax.as_tuple()) - t.maxresults.append(str(t.rmax.imag)) - t.maxresults.append(str(t.rmax.real)) - nc = t.rc.number_class().lstrip('+-s') stat[nc] += 1 else: @@ -848,9 +732,6 @@ def verify(t, stat): if not isinstance(t.rc, tuple) and not isinstance(t.rp, tuple): if t.rc != t.rp: raise_error(t) - if t.with_maxcontext and not isinstance(t.rmax, tuple): - if t.rmax != t.rc: - raise_error(t) stat[type(t.rc).__name__] += 1 # The return value lists must be equal. @@ -863,20 +744,6 @@ def verify(t, stat): if not t.context.assert_eq_status(): raise_error(t) - if t.with_maxcontext: - # NaN payloads etc. depend on precision and clamp. - if all_nan(t.rc) and all_nan(t.rmax): - return - # The return value lists must be equal. - if t.maxresults != t.cresults: - raise_error(t) - # The Python exception lists (TypeError, etc.) must be equal. - if t.maxex != t.cex: - raise_error(t) - # The context flags must be equal. - if t.maxcontext.flags != t.context.c.flags: - raise_error(t) - # ====================================================================== # Main test loops -- cgit v1.2.3-65-gdbad