diff --git a/python/atoms/prod.h b/python/atoms/prod.h new file mode 100644 index 00000000..9e158186 --- /dev/null +++ b/python/atoms/prod.h @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +#ifndef ATOM_PROD_H +#define ATOM_PROD_H + +#include "common.h" +#include "other.h" + +static PyObject *py_make_prod(PyObject *self, PyObject *args) +{ + (void) self; + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_prod(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create prod node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_PROD_H */ diff --git a/python/bindings.c b/python/bindings.c index 97c2906b..15ca7326 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -21,6 +21,7 @@ #include "atoms/multiply.h" #include "atoms/neg.h" #include "atoms/power.h" +#include "atoms/prod.h" #include "atoms/promote.h" #include "atoms/quad_form.h" #include "atoms/quad_over_lin.h" @@ -72,6 +73,7 @@ static PyMethodDef DNLPMethods[] = { {"make_const_vector_mult", py_make_const_vector_mult, METH_VARARGS, "Create constant vector multiplication node (a ∘ f(x))"}, {"make_power", py_make_power, METH_VARARGS, "Create power node"}, + {"make_prod", py_make_prod, METH_VARARGS, "Create prod node"}, {"make_sin", py_make_sin, METH_VARARGS, "Create sin node"}, {"make_cos", py_make_cos, METH_VARARGS, "Create cos node"}, {"make_tan", py_make_tan, METH_VARARGS, "Create tan node"}, diff --git a/src/dnlp_diff_engine/__init__.py b/src/dnlp_diff_engine/__init__.py index c8806845..4b8efb14 100644 --- a/src/dnlp_diff_engine/__init__.py +++ b/src/dnlp_diff_engine/__init__.py @@ -226,6 +226,8 @@ def _convert_reshape(expr, children): children[0], _extract_flat_indices_from_special_index(expr) ), "reshape": _convert_reshape, + # Reductions returning scalar + "Prod": lambda _expr, children: _diffengine.make_prod(children[0]), }