aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Krah <skrah@bytereef.org>2020-08-10 16:32:21 +0200
committerGitHub <noreply@github.com>2020-08-10 16:32:21 +0200
commit39042e00ab01d6521548c1b7cc6554c09f4389ff (patch)
treeea0d0a0cd9c73afe30a3573198e2d3a5844f200e /Modules
parentbpo-41514: Fix buggy IDLE test (GH-21808) (diff)
downloadcpython-39042e00ab01d6521548c1b7cc6554c09f4389ff.tar.gz
cpython-39042e00ab01d6521548c1b7cc6554c09f4389ff.tar.bz2
cpython-39042e00ab01d6521548c1b7cc6554c09f4389ff.zip
bpo-41324 Add a minimal decimal capsule API (#21519)
Diffstat (limited to 'Modules')
-rw-r--r--Modules/_decimal/_decimal.c185
-rw-r--r--Modules/_decimal/tests/deccheck.py74
-rw-r--r--Modules/_testcapimodule.c253
3 files changed, 505 insertions, 7 deletions
diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c
index fb4e020f126..e7c44acba02 100644
--- a/Modules/_decimal/_decimal.c
+++ b/Modules/_decimal/_decimal.c
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2008-2012 Stefan Krah. All rights reserved.
+ * Copyright (c) 2008-2020 Stefan Krah. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -33,6 +33,8 @@
#include <stdlib.h>
+#define CPYTHON_DECIMAL_MODULE
+#include "pydecimal.h"
#include "docstrings.h"
@@ -5555,6 +5557,160 @@ static PyTypeObject PyDecContext_Type =
};
+/****************************************************************************/
+/* C-API */
+/****************************************************************************/
+
+static void *_decimal_api[CPYTHON_DECIMAL_MAX_API];
+
+/* Simple API */
+static int
+PyDec_TypeCheck(const PyObject *v)
+{
+ return PyDec_Check(v);
+}
+
+static int
+PyDec_IsSpecial(const PyObject *v)
+{
+ if (!PyDec_Check(v)) {
+ PyErr_SetString(PyExc_TypeError,
+ "PyDec_IsSpecial: argument must be a Decimal");
+ return -1;
+ }
+
+ return mpd_isspecial(MPD(v));
+}
+
+static int
+PyDec_IsNaN(const PyObject *v)
+{
+ if (!PyDec_Check(v)) {
+ PyErr_SetString(PyExc_TypeError,
+ "PyDec_IsNaN: argument must be a Decimal");
+ return -1;
+ }
+
+ return mpd_isnan(MPD(v));
+}
+
+static int
+PyDec_IsInfinite(const PyObject *v)
+{
+ if (!PyDec_Check(v)) {
+ PyErr_SetString(PyExc_TypeError,
+ "PyDec_IsInfinite: argument must be a Decimal");
+ return -1;
+ }
+
+ return mpd_isinfinite(MPD(v));
+}
+
+static int64_t
+PyDec_GetDigits(const PyObject *v)
+{
+ if (!PyDec_Check(v)) {
+ PyErr_SetString(PyExc_TypeError,
+ "PyDec_GetDigits: argument must be a Decimal");
+ return -1;
+ }
+
+ return MPD(v)->digits;
+}
+
+static mpd_uint128_triple_t
+PyDec_AsUint128Triple(const PyObject *v)
+{
+ if (!PyDec_Check(v)) {
+ mpd_uint128_triple_t triple = { MPD_TRIPLE_ERROR, 0, 0, 0, 0 };
+ PyErr_SetString(PyExc_TypeError,
+ "PyDec_AsUint128Triple: argument must be a Decimal");
+ return triple;
+ }
+
+ return mpd_as_uint128_triple(MPD(v));
+}
+
+static PyObject *
+PyDec_FromUint128Triple(const mpd_uint128_triple_t *triple)
+{
+ PyObject *context;
+ PyObject *result;
+ uint32_t status = 0;
+
+ CURRENT_CONTEXT(context);
+
+ result = dec_alloc();
+ if (result == NULL) {
+ return NULL;
+ }
+
+ if (mpd_from_uint128_triple(MPD(result), triple, &status) < 0) {
+ if (dec_addstatus(context, status)) {
+ Py_DECREF(result);
+ return NULL;
+ }
+ }
+
+ return result;
+}
+
+/* Advanced API */
+static PyObject *
+PyDec_Alloc(void)
+{
+ return dec_alloc();
+}
+
+static mpd_t *
+PyDec_Get(PyObject *v)
+{
+ if (!PyDec_Check(v)) {
+ PyErr_SetString(PyExc_TypeError,
+ "PyDec_Get: argument must be a Decimal");
+ return NULL;
+ }
+
+ return MPD(v);
+}
+
+static const mpd_t *
+PyDec_GetConst(const PyObject *v)
+{
+ if (!PyDec_Check(v)) {
+ PyErr_SetString(PyExc_TypeError,
+ "PyDec_GetConst: argument must be a Decimal");
+ return NULL;
+ }
+
+ return MPD(v);
+}
+
+static PyObject *
+init_api(void)
+{
+ /* Simple API */
+ _decimal_api[PyDec_TypeCheck_INDEX] = (void *)PyDec_TypeCheck;
+ _decimal_api[PyDec_IsSpecial_INDEX] = (void *)PyDec_IsSpecial;
+ _decimal_api[PyDec_IsNaN_INDEX] = (void *)PyDec_IsNaN;
+ _decimal_api[PyDec_IsInfinite_INDEX] = (void *)PyDec_IsInfinite;
+ _decimal_api[PyDec_GetDigits_INDEX] = (void *)PyDec_GetDigits;
+ _decimal_api[PyDec_AsUint128Triple_INDEX] = (void *)PyDec_AsUint128Triple;
+ _decimal_api[PyDec_FromUint128Triple_INDEX] = (void *)PyDec_FromUint128Triple;
+
+ /* Advanced API */
+ _decimal_api[PyDec_Alloc_INDEX] = (void *)PyDec_Alloc;
+ _decimal_api[PyDec_Get_INDEX] = (void *)PyDec_Get;
+ _decimal_api[PyDec_GetConst_INDEX] = (void *)PyDec_GetConst;
+
+ return PyCapsule_New(_decimal_api, "_decimal._API", NULL);
+}
+
+
+/****************************************************************************/
+/* Module */
+/****************************************************************************/
+
static PyMethodDef _decimal_methods [] =
{
{ "getcontext", (PyCFunction)PyDec_GetCurrentContext, METH_NOARGS, doc_getcontext},
@@ -5665,17 +5821,27 @@ PyInit__decimal(void)
DecCondMap *cm;
struct ssize_constmap *ssize_cm;
struct int_constmap *int_cm;
+ static PyObject *capsule = NULL;
+ static int initialized = 0;
int i;
/* Init libmpdec */
- mpd_traphandler = dec_traphandler;
- mpd_mallocfunc = PyMem_Malloc;
- mpd_reallocfunc = PyMem_Realloc;
- mpd_callocfunc = mpd_callocfunc_em;
- mpd_free = PyMem_Free;
- mpd_setminalloc(_Py_DEC_MINALLOC);
+ if (!initialized) {
+ mpd_traphandler = dec_traphandler;
+ mpd_mallocfunc = PyMem_Malloc;
+ mpd_reallocfunc = PyMem_Realloc;
+ mpd_callocfunc = mpd_callocfunc_em;
+ mpd_free = PyMem_Free;
+ mpd_setminalloc(_Py_DEC_MINALLOC);
+
+ capsule = init_api();
+ if (capsule == NULL) {
+ return NULL;
+ }
+ initialized = 1;
+ }
/* Init external C-API functions */
_py_long_multiply = PyLong_Type.tp_as_number->nb_multiply;
@@ -5900,6 +6066,11 @@ PyInit__decimal(void)
CHECK_INT(PyModule_AddStringConstant(m, "__version__", "1.70"));
CHECK_INT(PyModule_AddStringConstant(m, "__libmpdec_version__", mpd_version()));
+ /* Add capsule API */
+ Py_INCREF(capsule);
+ if (PyModule_AddObject(m, "_API", capsule) < 0) {
+ goto error;
+ }
return m;
diff --git a/Modules/_decimal/tests/deccheck.py b/Modules/_decimal/tests/deccheck.py
index 5d9179e6168..15f104dc463 100644
--- a/Modules/_decimal/tests/deccheck.py
+++ b/Modules/_decimal/tests/deccheck.py
@@ -49,6 +49,9 @@ from randdec import unary_optarg, binary_optarg, ternary_optarg
from formathelper import rand_format, rand_locale
from _pydecimal import _dec_from_triple
+from _testcapi import decimal_as_triple
+from _testcapi import decimal_from_triple
+
C = import_fresh_module('decimal', fresh=['_decimal'])
P = import_fresh_module('decimal', blocked=['_decimal'])
EXIT_STATUS = 0
@@ -154,6 +157,45 @@ TernaryRestricted = ['__pow__', 'context.power']
# ======================================================================
+# Triple tests
+# ======================================================================
+
+def c_as_triple(dec):
+ sign, hi, lo, exp = decimal_as_triple(dec)
+
+ coeff = hi * 2**64 + lo
+ return (sign, coeff, exp)
+
+def c_from_triple(triple):
+ sign, coeff, exp = triple
+
+ hi = coeff // 2**64
+ lo = coeff % 2**64
+ return decimal_from_triple((sign, hi, lo, exp))
+
+def p_as_triple(dec):
+ sign, digits, exp = dec.as_tuple()
+
+ s = "".join(str(d) for d in digits)
+ coeff = int(s) if s else 0
+
+ if coeff < 0 or coeff >= 2**128:
+ raise ValueError("value out of bounds for a uint128 triple");
+
+ return (sign, coeff, exp)
+
+def p_from_triple(triple):
+ sign, coeff, exp = triple
+
+ if coeff < 0 or coeff >= 2**128:
+ raise ValueError("value out of bounds for a uint128 triple");
+
+ digits = tuple(int(c) for c in str(coeff))
+
+ return P.Decimal((sign, digits, exp))
+
+
+# ======================================================================
# Unified Context
# ======================================================================
@@ -846,12 +888,44 @@ def verify(t, stat):
t.presults.append(str(t.rp.imag))
t.presults.append(str(t.rp.real))
+ ctriple = None
+ if t.funcname not in ['__radd__', '__rmul__']: # see skip handler
+ try:
+ ctriple = c_as_triple(t.rc)
+ except ValueError:
+ try:
+ ptriple = p_as_triple(t.rp)
+ except ValueError:
+ pass
+ else:
+ raise RuntimeError("ValueError not raised")
+ else:
+ cres = c_from_triple(ctriple)
+ t.cresults.append(ctriple)
+ t.cresults.append(str(cres))
+
+ ptriple = p_as_triple(t.rp)
+ pres = p_from_triple(ptriple)
+ t.presults.append(ptriple)
+ t.presults.append(str(pres))
+
if t.with_maxcontext and isinstance(t.rmax, C.Decimal):
t.maxresults.append(t.rmax.to_eng_string())
t.maxresults.append(t.rmax.as_tuple())
t.maxresults.append(str(t.rmax.imag))
t.maxresults.append(str(t.rmax.real))
+ if ctriple is not None:
+ # NaN payloads etc. depend on precision and clamp.
+ if all_nan(t.rc) and all_nan(t.rmax):
+ t.maxresults.append(ctriple)
+ t.maxresults.append(str(cres))
+ else:
+ maxtriple = c_as_triple(t.rmax)
+ maxres = c_from_triple(maxtriple)
+ t.maxresults.append(maxtriple)
+ t.maxresults.append(str(maxres))
+
nc = t.rc.number_class().lstrip('+-s')
stat[nc] += 1
else:
diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c
index fca94a83a5d..593034ef65e 100644
--- a/Modules/_testcapimodule.c
+++ b/Modules/_testcapimodule.c
@@ -19,6 +19,7 @@
#include "Python.h"
#include "datetime.h"
+#include "pydecimal.h"
#include "marshal.h"
#include "structmember.h" // PyMemberDef
#include <float.h>
@@ -2705,6 +2706,252 @@ test_PyDateTime_DELTA_GET(PyObject *self, PyObject *obj)
return Py_BuildValue("(lll)", days, seconds, microseconds);
}
+/* Test decimal API */
+static int decimal_initialized = 0;
+static PyObject *
+decimal_is_special(PyObject *module, PyObject *dec)
+{
+ int is_special;
+
+ (void)module;
+ if (!decimal_initialized) {
+ if (import_decimal() < 0) {
+ return NULL;
+ }
+
+ decimal_initialized = 1;
+ }
+
+ is_special = PyDec_IsSpecial(dec);
+ if (is_special < 0) {
+ return NULL;
+ }
+
+ return PyBool_FromLong(is_special);
+}
+
+static PyObject *
+decimal_is_nan(PyObject *module, PyObject *dec)
+{
+ int is_nan;
+
+ (void)module;
+ if (!decimal_initialized) {
+ if (import_decimal() < 0) {
+ return NULL;
+ }
+
+ decimal_initialized = 1;
+ }
+
+ is_nan = PyDec_IsNaN(dec);
+ if (is_nan < 0) {
+ return NULL;
+ }
+
+ return PyBool_FromLong(is_nan);
+}
+
+static PyObject *
+decimal_is_infinite(PyObject *module, PyObject *dec)
+{
+ int is_infinite;
+
+ (void)module;
+ if (!decimal_initialized) {
+ if (import_decimal() < 0) {
+ return NULL;
+ }
+
+ decimal_initialized = 1;
+ }
+
+ is_infinite = PyDec_IsInfinite(dec);
+ if (is_infinite < 0) {
+ return NULL;
+ }
+
+ return PyBool_FromLong(is_infinite);
+}
+
+static PyObject *
+decimal_get_digits(PyObject *module, PyObject *dec)
+{
+ int64_t digits;
+
+ (void)module;
+ if (!decimal_initialized) {
+ if (import_decimal() < 0) {
+ return NULL;
+ }
+
+ decimal_initialized = 1;
+ }
+
+ digits = PyDec_GetDigits(dec);
+ if (digits < 0) {
+ return NULL;
+ }
+
+ return PyLong_FromLongLong(digits);
+}
+
+static PyObject *
+decimal_as_triple(PyObject *module, PyObject *dec)
+{
+ PyObject *tuple = NULL;
+ PyObject *sign, *hi, *lo;
+ mpd_uint128_triple_t triple;
+
+ (void)module;
+ if (!decimal_initialized) {
+ if (import_decimal() < 0) {
+ return NULL;
+ }
+
+ decimal_initialized = 1;
+ }
+
+ triple = PyDec_AsUint128Triple(dec);
+ if (triple.tag == MPD_TRIPLE_ERROR && PyErr_Occurred()) {
+ return NULL;
+ }
+
+ sign = PyLong_FromUnsignedLong(triple.sign);
+ if (sign == NULL) {
+ return NULL;
+ }
+
+ hi = PyLong_FromUnsignedLongLong(triple.hi);
+ if (hi == NULL) {
+ Py_DECREF(sign);
+ return NULL;
+ }
+
+ lo = PyLong_FromUnsignedLongLong(triple.lo);
+ if (lo == NULL) {
+ Py_DECREF(hi);
+ Py_DECREF(sign);
+ return NULL;
+ }
+
+ switch (triple.tag) {
+ case MPD_TRIPLE_QNAN:
+ assert(triple.exp == 0);
+ tuple = Py_BuildValue("(OOOs)", sign, hi, lo, "n");
+ break;
+
+ case MPD_TRIPLE_SNAN:
+ assert(triple.exp == 0);
+ tuple = Py_BuildValue("(OOOs)", sign, hi, lo, "N");
+ break;
+
+ case MPD_TRIPLE_INF:
+ assert(triple.hi == 0);
+ assert(triple.lo == 0);
+ assert(triple.exp == 0);
+ tuple = Py_BuildValue("(OOOs)", sign, hi, lo, "F");
+ break;
+
+ case MPD_TRIPLE_NORMAL:
+ tuple = Py_BuildValue("(OOOL)", sign, hi, lo, triple.exp);
+ break;
+
+ case MPD_TRIPLE_ERROR:
+ PyErr_SetString(PyExc_ValueError,
+ "value out of bounds for a uint128 triple");
+ break;
+
+ default:
+ PyErr_SetString(PyExc_RuntimeError,
+ "decimal_as_triple: internal error: unexpected tag");
+ break;
+ }
+
+ Py_DECREF(lo);
+ Py_DECREF(hi);
+ Py_DECREF(sign);
+
+ return tuple;
+}
+
+static PyObject *
+decimal_from_triple(PyObject *module, PyObject *tuple)
+{
+ mpd_uint128_triple_t triple = { MPD_TRIPLE_ERROR, 0, 0, 0, 0 };
+ PyObject *exp;
+ unsigned long sign;
+
+ (void)module;
+ if (!decimal_initialized) {
+ if (import_decimal() < 0) {
+ return NULL;
+ }
+
+ decimal_initialized = 1;
+ }
+
+ if (!PyTuple_Check(tuple)) {
+ PyErr_SetString(PyExc_TypeError, "argument must be a tuple");
+ return NULL;
+ }
+
+ if (PyTuple_GET_SIZE(tuple) != 4) {
+ PyErr_SetString(PyExc_ValueError, "tuple size must be 4");
+ return NULL;
+ }
+
+ sign = PyLong_AsUnsignedLong(PyTuple_GET_ITEM(tuple, 0));
+ if (sign == (unsigned long)-1 && PyErr_Occurred()) {
+ return NULL;
+ }
+ if (sign > UINT8_MAX) {
+ PyErr_SetString(PyExc_ValueError, "sign must be 0 or 1");
+ return NULL;
+ }
+ triple.sign = (uint8_t)sign;
+
+ triple.hi = PyLong_AsUnsignedLongLong(PyTuple_GET_ITEM(tuple, 1));
+ if (triple.hi == (unsigned long long)-1 && PyErr_Occurred()) {
+ return NULL;
+ }
+
+ triple.lo = PyLong_AsUnsignedLongLong(PyTuple_GET_ITEM(tuple, 2));
+ if (triple.lo == (unsigned long long)-1 && PyErr_Occurred()) {
+ return NULL;
+ }
+
+ exp = PyTuple_GET_ITEM(tuple, 3);
+ if (PyLong_Check(exp)) {
+ triple.tag = MPD_TRIPLE_NORMAL;
+ triple.exp = PyLong_AsLongLong(exp);
+ if (triple.exp == -1 && PyErr_Occurred()) {
+ return NULL;
+ }
+ }
+ else if (PyUnicode_Check(exp)) {
+ if (PyUnicode_CompareWithASCIIString(exp, "F") == 0) {
+ triple.tag = MPD_TRIPLE_INF;
+ }
+ else if (PyUnicode_CompareWithASCIIString(exp, "n") == 0) {
+ triple.tag = MPD_TRIPLE_QNAN;
+ }
+ else if (PyUnicode_CompareWithASCIIString(exp, "N") == 0) {
+ triple.tag = MPD_TRIPLE_SNAN;
+ }
+ else {
+ PyErr_SetString(PyExc_ValueError, "not a valid exponent");
+ return NULL;
+ }
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError, "exponent must be int or string");
+ return NULL;
+ }
+
+ return PyDec_FromUint128Triple(&triple);
+}
+
/* test_thread_state spawns a thread of its own, and that thread releases
* `thread_done` when it's finished. The driver code has to know when the
* thread finishes, because the thread uses a PyObject (the callable) that
@@ -5314,6 +5561,12 @@ static PyMethodDef TestMethods[] = {
{"PyDateTime_DATE_GET", test_PyDateTime_DATE_GET, METH_O},
{"PyDateTime_TIME_GET", test_PyDateTime_TIME_GET, METH_O},
{"PyDateTime_DELTA_GET", test_PyDateTime_DELTA_GET, METH_O},
+ {"decimal_is_special", decimal_is_special, METH_O},
+ {"decimal_is_nan", decimal_is_nan, METH_O},
+ {"decimal_is_infinite", decimal_is_infinite, METH_O},
+ {"decimal_get_digits", decimal_get_digits, METH_O},
+ {"decimal_as_triple", decimal_as_triple, METH_O},
+ {"decimal_from_triple", decimal_from_triple, METH_O},
{"test_list_api", test_list_api, METH_NOARGS},
{"test_dict_iteration", test_dict_iteration, METH_NOARGS},
{"dict_getitem_knownhash", dict_getitem_knownhash, METH_VARARGS},