summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Krah <skrah@bytereef.org>2020-06-09 01:55:47 +0200
committerGitHub <noreply@github.com>2020-06-09 01:55:47 +0200
commit22faf6ad3bcc0ae478a9a3e2d8e35888d88d6ce8 (patch)
tree90601649b90e24c507f1582e3289e681db051271
parent[3.7] Revert bpo-39576: docs: set context for decimal arbitrary precision ari... (diff)
downloadcpython-22faf6ad3bcc0ae478a9a3e2d8e35888d88d6ce8.tar.gz
cpython-22faf6ad3bcc0ae478a9a3e2d8e35888d88d6ce8.tar.bz2
cpython-22faf6ad3bcc0ae478a9a3e2d8e35888d88d6ce8.zip
[3.7] Revert bpo-39576: Prevent memory error for overly optimistic precisions (GH-20748)
This reverts commit c6f95543b4832c3f0170179da39bcf99b40a7aa8.
-rw-r--r--Lib/test/test_decimal.py35
-rw-r--r--Modules/_decimal/libmpdec/mpdecimal.c77
-rw-r--r--Modules/_decimal/tests/deccheck.py139
3 files changed, 6 insertions, 245 deletions
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 0e9cd3095c8..1f37b5372a3 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -5476,41 +5476,6 @@ class CWhitebox(unittest.TestCase):
self.assertEqual(Decimal.from_float(cls(101.1)),
Decimal.from_float(101.1))
- def test_maxcontext_exact_arith(self):
-
- # Make sure that exact operations do not raise MemoryError due
- # to huge intermediate values when the context precision is very
- # large.
-
- # The following functions fill the available precision and are
- # therefore not suitable for large precisions (by design of the
- # specification).
- MaxContextSkip = ['logical_invert', 'next_minus', 'next_plus',
- 'logical_and', 'logical_or', 'logical_xor',
- 'next_toward', 'rotate', 'shift']
-
- Decimal = C.Decimal
- Context = C.Context
- localcontext = C.localcontext
-
- # Here only some functions that are likely candidates for triggering a
- # MemoryError are tested. deccheck.py has an exhaustive test.
- maxcontext = Context(prec=C.MAX_PREC, Emin=C.MIN_EMIN, Emax=C.MAX_EMAX)
- with localcontext(maxcontext):
- self.assertEqual(Decimal(0).exp(), 1)
- self.assertEqual(Decimal(1).ln(), 0)
- self.assertEqual(Decimal(1).log10(), 0)
- self.assertEqual(Decimal(10**2).log10(), 2)
- self.assertEqual(Decimal(10**223).log10(), 223)
- self.assertEqual(Decimal(10**19).logb(), 19)
- self.assertEqual(Decimal(4).sqrt(), 2)
- self.assertEqual(Decimal("40E9").sqrt(), Decimal('2.0E+5'))
- self.assertEqual(divmod(Decimal(10), 3), (3, 1))
- self.assertEqual(Decimal(10) // 3, 3)
- self.assertEqual(Decimal(4) / 2, 2)
- self.assertEqual(Decimal(400) ** -1, Decimal('0.0025'))
-
-
@requires_docstrings
@unittest.skipUnless(C, "test requires C version")
class SignatureTest(unittest.TestCase):
diff --git a/Modules/_decimal/libmpdec/mpdecimal.c b/Modules/_decimal/libmpdec/mpdecimal.c
index 0986edb576a..bfa8bb343e6 100644
--- a/Modules/_decimal/libmpdec/mpdecimal.c
+++ b/Modules/_decimal/libmpdec/mpdecimal.c
@@ -3781,43 +3781,6 @@ mpd_qdiv(mpd_t *q, const mpd_t *a, const mpd_t *b,
const mpd_context_t *ctx, uint32_t *status)
{
_mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status);
-
- if (*status & MPD_Malloc_error) {
- /* Inexact quotients (the usual case) fill the entire context precision,
- * which can lead to malloc() failures for very high precisions. Retry
- * the operation with a lower precision in case the result is exact.
- *
- * We need an upper bound for the number of digits of a_coeff / b_coeff
- * when the result is exact. If a_coeff' * 1 / b_coeff' is in lowest
- * terms, then maxdigits(a_coeff') + maxdigits(1 / b_coeff') is a suitable
- * bound.
- *
- * 1 / b_coeff' is exact iff b_coeff' exclusively has prime factors 2 or 5.
- * The largest amount of digits is generated if b_coeff' is a power of 2 or
- * a power of 5 and is less than or equal to log5(b_coeff') <= log2(b_coeff').
- *
- * We arrive at a total upper bound:
- *
- * maxdigits(a_coeff') + maxdigits(1 / b_coeff') <=
- * a->digits + log2(b_coeff) =
- * a->digits + log10(b_coeff) / log10(2) <=
- * a->digits + b->digits * 4;
- */
- uint32_t workstatus = 0;
- mpd_context_t workctx = *ctx;
- workctx.prec = a->digits + b->digits * 4;
- if (workctx.prec >= ctx->prec) {
- return; /* No point in retrying, keep the original error. */
- }
-
- _mpd_qdiv(SET_IDEAL_EXP, q, a, b, &workctx, &workstatus);
- if (workstatus == 0) { /* The result is exact, unrounded, normal etc. */
- *status = 0;
- return;
- }
-
- mpd_seterror(q, *status, status);
- }
}
/* Internal function. */
@@ -7739,9 +7702,9 @@ mpd_qinvroot(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
/* END LIBMPDEC_ONLY */
/* Algorithm from decimal.py */
-static void
-_mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
- uint32_t *status)
+void
+mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
+ uint32_t *status)
{
mpd_context_t maxcontext;
MPD_NEW_STATIC(c,0,0,0,0);
@@ -7873,40 +7836,6 @@ malloc_error:
goto out;
}
-void
-mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
- uint32_t *status)
-{
- _mpd_qsqrt(result, a, ctx, status);
-
- if (*status & (MPD_Malloc_error|MPD_Division_impossible)) {
- /* The above conditions can occur at very high context precisions
- * if intermediate values get too large. Retry the operation with
- * a lower context precision in case the result is exact.
- *
- * If the result is exact, an upper bound for the number of digits
- * is the number of digits in the input.
- *
- * NOTE: sqrt(40e9) = 2.0e+5 /\ digits(40e9) = digits(2.0e+5) = 2
- */
- uint32_t workstatus = 0;
- mpd_context_t workctx = *ctx;
- workctx.prec = a->digits;
-
- if (workctx.prec >= ctx->prec) {
- return; /* No point in repeating this, keep the original error. */
- }
-
- _mpd_qsqrt(result, a, &workctx, &workstatus);
- if (workstatus == 0) {
- *status = 0;
- return;
- }
-
- mpd_seterror(result, *status, status);
- }
-}
-
/******************************************************************************/
/* Base conversions */
diff --git a/Modules/_decimal/tests/deccheck.py b/Modules/_decimal/tests/deccheck.py
index 5cd5db57114..f907531e1ff 100644
--- a/Modules/_decimal/tests/deccheck.py
+++ b/Modules/_decimal/tests/deccheck.py
@@ -125,12 +125,6 @@ ContextFunctions = {
'special': ('context.__reduce_ex__', 'context.create_decimal_from_float')
}
-# Functions that set no context flags but whose result can differ depending
-# on prec, Emin and Emax.
-MaxContextSkip = ['is_normal', 'is_subnormal', 'logical_invert', 'next_minus',
- 'next_plus', 'number_class', 'logical_and', 'logical_or',
- 'logical_xor', 'next_toward', 'rotate', 'shift']
-
# Functions that require a restricted exponent range for reasonable runtimes.
UnaryRestricted = [
'__ceil__', '__floor__', '__int__', '__trunc__',
@@ -350,20 +344,6 @@ class TestSet(object):
self.pex = RestrictedList() # Python exceptions for P.Decimal
self.presults = RestrictedList() # P.Decimal results
- # If the above results are exact, unrounded and not clamped, repeat
- # the operation with a maxcontext to ensure that huge intermediate
- # values do not cause a MemoryError.
- self.with_maxcontext = False
- self.maxcontext = context.c.copy()
- self.maxcontext.prec = C.MAX_PREC
- self.maxcontext.Emax = C.MAX_EMAX
- self.maxcontext.Emin = C.MIN_EMIN
- self.maxcontext.clear_flags()
-
- self.maxop = RestrictedList() # converted C.Decimal operands
- self.maxex = RestrictedList() # Python exceptions for C.Decimal
- self.maxresults = RestrictedList() # C.Decimal results
-
# ======================================================================
# SkipHandler: skip known discrepancies
@@ -565,17 +545,13 @@ def function_as_string(t):
if t.contextfunc:
cargs = t.cop
pargs = t.pop
- maxargs = t.maxop
cfunc = "c_func: %s(" % t.funcname
pfunc = "p_func: %s(" % t.funcname
- maxfunc = "max_func: %s(" % t.funcname
else:
cself, cargs = t.cop[0], t.cop[1:]
pself, pargs = t.pop[0], t.pop[1:]
- maxself, maxargs = t.maxop[0], t.maxop[1:]
cfunc = "c_func: %s.%s(" % (repr(cself), t.funcname)
pfunc = "p_func: %s.%s(" % (repr(pself), t.funcname)
- maxfunc = "max_func: %s.%s(" % (repr(maxself), t.funcname)
err = cfunc
for arg in cargs:
@@ -589,14 +565,6 @@ def function_as_string(t):
err = err.rstrip(", ")
err += ")"
- if t.with_maxcontext:
- err += "\n"
- err += maxfunc
- for arg in maxargs:
- err += "%s, " % repr(arg)
- err = err.rstrip(", ")
- err += ")"
-
return err
def raise_error(t):
@@ -609,24 +577,9 @@ def raise_error(t):
err = "Error in %s:\n\n" % t.funcname
err += "input operands: %s\n\n" % (t.op,)
err += function_as_string(t)
-
- err += "\n\nc_result: %s\np_result: %s\n" % (t.cresults, t.presults)
- if t.with_maxcontext:
- err += "max_result: %s\n\n" % (t.maxresults)
- else:
- err += "\n"
-
- err += "c_exceptions: %s\np_exceptions: %s\n" % (t.cex, t.pex)
- if t.with_maxcontext:
- err += "max_exceptions: %s\n\n" % t.maxex
- else:
- err += "\n"
-
- err += "%s\n" % str(t.context)
- if t.with_maxcontext:
- err += "%s\n" % str(t.maxcontext)
- else:
- err += "\n"
+ err += "\n\nc_result: %s\np_result: %s\n\n" % (t.cresults, t.presults)
+ err += "c_exceptions: %s\np_exceptions: %s\n\n" % (t.cex, t.pex)
+ err += "%s\n\n" % str(t.context)
raise VerifyError(err)
@@ -650,13 +603,6 @@ def raise_error(t):
# are printed to stdout.
# ======================================================================
-def all_nan(a):
- if isinstance(a, C.Decimal):
- return a.is_nan()
- elif isinstance(a, tuple):
- return all(all_nan(v) for v in a)
- return False
-
def convert(t, convstr=True):
""" t is the testset. At this stage the testset contains a tuple of
operands t.op of various types. For decimal methods the first
@@ -671,12 +617,10 @@ def convert(t, convstr=True):
for i, op in enumerate(t.op):
context.clear_status()
- t.maxcontext.clear_flags()
if op in RoundModes:
t.cop.append(op)
t.pop.append(op)
- t.maxop.append(op)
elif not t.contextfunc and i == 0 or \
convstr and isinstance(op, str):
@@ -694,25 +638,11 @@ def convert(t, convstr=True):
p = None
pex = e.__class__
- try:
- C.setcontext(t.maxcontext)
- maxop = C.Decimal(op)
- maxex = None
- except (TypeError, ValueError, OverflowError) as e:
- maxop = None
- maxex = e.__class__
- finally:
- C.setcontext(context.c)
-
t.cop.append(c)
t.cex.append(cex)
-
t.pop.append(p)
t.pex.append(pex)
- t.maxop.append(maxop)
- t.maxex.append(maxex)
-
if cex is pex:
if str(c) != str(p) or not context.assert_eq_status():
raise_error(t)
@@ -722,21 +652,14 @@ def convert(t, convstr=True):
else:
raise_error(t)
- # The exceptions in the maxcontext operation can legitimately
- # differ, only test that maxex implies cex:
- if maxex is not None and cex is not maxex:
- raise_error(t)
-
elif isinstance(op, Context):
t.context = op
t.cop.append(op.c)
t.pop.append(op.p)
- t.maxop.append(t.maxcontext)
else:
t.cop.append(op)
t.pop.append(op)
- t.maxop.append(op)
return 1
@@ -750,7 +673,6 @@ def callfuncs(t):
t.rc and t.rp are the results of the operation.
"""
context.clear_status()
- t.maxcontext.clear_flags()
try:
if t.contextfunc:
@@ -778,35 +700,6 @@ def callfuncs(t):
t.rp = None
t.pex.append(e.__class__)
- # If the above results are exact, unrounded, normal etc., repeat the
- # operation with a maxcontext to ensure that huge intermediate values
- # do not cause a MemoryError.
- if (t.funcname not in MaxContextSkip and
- not context.c.flags[C.InvalidOperation] and
- not context.c.flags[C.Inexact] and
- not context.c.flags[C.Rounded] and
- not context.c.flags[C.Subnormal] and
- not context.c.flags[C.Clamped] and
- not context.clamp and # results are padded to context.prec if context.clamp==1.
- not any(isinstance(v, C.Context) for v in t.cop)): # another context is used.
- t.with_maxcontext = True
- try:
- if t.contextfunc:
- maxargs = t.maxop
- t.rmax = getattr(t.maxcontext, t.funcname)(*maxargs)
- else:
- maxself = t.maxop[0]
- maxargs = t.maxop[1:]
- try:
- C.setcontext(t.maxcontext)
- t.rmax = getattr(maxself, t.funcname)(*maxargs)
- finally:
- C.setcontext(context.c)
- t.maxex.append(None)
- except (TypeError, ValueError, OverflowError, MemoryError) as e:
- t.rmax = None
- t.maxex.append(e.__class__)
-
def verify(t, stat):
""" t is the testset. At this stage the testset contains the following
tuples:
@@ -821,9 +714,6 @@ def verify(t, stat):
"""
t.cresults.append(str(t.rc))
t.presults.append(str(t.rp))
- if t.with_maxcontext:
- t.maxresults.append(str(t.rmax))
-
if isinstance(t.rc, C.Decimal) and isinstance(t.rp, P.Decimal):
# General case: both results are Decimals.
t.cresults.append(t.rc.to_eng_string())
@@ -835,12 +725,6 @@ def verify(t, stat):
t.presults.append(str(t.rp.imag))
t.presults.append(str(t.rp.real))
- if t.with_maxcontext and isinstance(t.rmax, C.Decimal):
- t.maxresults.append(t.rmax.to_eng_string())
- t.maxresults.append(t.rmax.as_tuple())
- t.maxresults.append(str(t.rmax.imag))
- t.maxresults.append(str(t.rmax.real))
-
nc = t.rc.number_class().lstrip('+-s')
stat[nc] += 1
else:
@@ -848,9 +732,6 @@ def verify(t, stat):
if not isinstance(t.rc, tuple) and not isinstance(t.rp, tuple):
if t.rc != t.rp:
raise_error(t)
- if t.with_maxcontext and not isinstance(t.rmax, tuple):
- if t.rmax != t.rc:
- raise_error(t)
stat[type(t.rc).__name__] += 1
# The return value lists must be equal.
@@ -863,20 +744,6 @@ def verify(t, stat):
if not t.context.assert_eq_status():
raise_error(t)
- if t.with_maxcontext:
- # NaN payloads etc. depend on precision and clamp.
- if all_nan(t.rc) and all_nan(t.rmax):
- return
- # The return value lists must be equal.
- if t.maxresults != t.cresults:
- raise_error(t)
- # The Python exception lists (TypeError, etc.) must be equal.
- if t.maxex != t.cex:
- raise_error(t)
- # The context flags must be equal.
- if t.maxcontext.flags != t.context.c.flags:
- raise_error(t)
-
# ======================================================================
# Main test loops