Skip to content

Commit

Permalink
API: Array API support - Part 1 (#612)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol authored Dec 21, 2023
1 parent 0e283ff commit 01d8934
Show file tree
Hide file tree
Showing 11 changed files with 591 additions and 21 deletions.
76 changes: 76 additions & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,79 @@

__version__ = get_versions()["version"]
del get_versions

from numpy import (
bool_ as bool,
float16,
float32,
float64,
complex64,
complex128,
uint8,
uint16,
uint32,
uint64,
int8,
int16,
int32,
int64,
pi,
e,
nan,
inf,
newaxis,
sin,
sinh,
cos,
cosh,
tan,
tanh,
arcsin as asin,
arcsinh as asinh,
arccos as acos,
arccosh as acosh,
arctan as atan,
arctan2 as atan2,
arctanh as atanh,
log,
log2,
log1p,
log10,
logaddexp,
power as pow,
sign,
square,
sqrt,
logical_and,
logical_not,
logical_or,
logical_xor,
bitwise_and,
bitwise_or,
bitwise_xor,
bitwise_not,
trunc,
add,
subtract,
remainder,
positive,
not_equal,
negative,
multiply,
less_equal,
less,
greater_equal,
greater,
floor_divide,
floor,
exp,
expm1,
divide,
ceil,
left_shift as bitwise_left_shift,
right_shift as bitwise_right_shift,
invert as bitwise_invert,
finfo,
iinfo,
can_cast,
)
226 changes: 211 additions & 15 deletions sparse/_common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import numpy as np
import numba
import scipy.sparse
import builtins
from collections.abc import Iterable
from functools import wraps, reduce
from itertools import chain
from operator import mul, index
from collections.abc import Iterable
import warnings

import numpy as np
import numba
import scipy.sparse
from scipy.sparse import spmatrix
from numba import literal_unroll
import warnings

from ._sparse_array import SparseArray
from ._utils import (
Expand All @@ -33,6 +35,8 @@
roll,
kron,
argwhere,
argmax,
argmin,
isposinf,
isneginf,
result_type,
Expand Down Expand Up @@ -187,7 +191,7 @@ def tensordot(a, b, axes=2, *, return_type=None):
newshape_b = (N2, -1)
oldb = [bs[axis] for axis in notin]

if any(dim == 0 for dim in chain(newshape_a, newshape_b)):
if builtins.any(dim == 0 for dim in chain(newshape_a, newshape_b)):
res = asCOO(np.empty(olda + oldb), check=False)
if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
res = res.todense()
Expand Down Expand Up @@ -268,12 +272,12 @@ def _matmul_recurser(a, b):
if a.ndim == 2:
return dot(a, b)
res = []
for i in range(max(a.shape[0], b.shape[0])):
for i in range(builtins.max(a.shape[0], b.shape[0])):
a_i = a[0] if a.shape[0] == 1 else a[i]
b_i = b[0] if b.shape[0] == 1 else b[i]
res.append(_matmul_recurser(a_i, b_i))
mask = [isinstance(x, SparseArray) for x in res]
if all(mask):
if builtins.all(mask):
return stack(res)
else:
res = [x.todense() if isinstance(x, SparseArray) else x for x in res]
Expand Down Expand Up @@ -334,7 +338,7 @@ def _dot(a, b, return_type=None):
from ._sparse_array import SparseArray

out_shape = (a.shape[0], b.shape[1])
if all(isinstance(arr, SparseArray) for arr in [a, b]) and any(
if builtins.all(isinstance(arr, SparseArray) for arr in [a, b]) and builtins.any(
isinstance(arr, GCXS) for arr in [a, b]
):
a = a.asformat("gcxs")
Expand Down Expand Up @@ -1333,7 +1337,7 @@ def _parse_einsum_input(operands):
if operands[num].shape == ():
ellipse_count = 0
else:
ellipse_count = max(operands[num].ndim, 1)
ellipse_count = builtins.max(operands[num].ndim, 1)
ellipse_count -= len(sub) - 3

if ellipse_count > longest:
Expand Down Expand Up @@ -1573,7 +1577,7 @@ def stack(arrays, axis=0, compressed_axes=None):
"""
from ._compressed import GCXS

if not all(isinstance(arr, GCXS) for arr in arrays):
if not builtins.all(isinstance(arr, GCXS) for arr in arrays):
from ._coo import stack as coo_stack

return coo_stack(arrays, axis)
Expand Down Expand Up @@ -1612,7 +1616,7 @@ def concatenate(arrays, axis=0, compressed_axes=None):
"""
from ._compressed import GCXS

if not all(isinstance(arr, GCXS) for arr in arrays):
if not builtins.all(isinstance(arr, GCXS) for arr in arrays):
from ._coo import concatenate as coo_concat

return coo_concat(arrays, axis)
Expand All @@ -1622,6 +1626,9 @@ def concatenate(arrays, axis=0, compressed_axes=None):
return gcxs_concat(arrays, axis, compressed_axes)


concat = concatenate


def eye(N, M=None, k=0, dtype=float, format="coo", **kwargs):
"""Return a 2-D array in the specified format with ones on the diagonal and zeros elsewhere.
Expand Down Expand Up @@ -1665,14 +1672,14 @@ def eye(N, M=None, k=0, dtype=float, format="coo", **kwargs):
M = int(M)
k = int(k)

data_length = min(N, M)
data_length = builtins.min(N, M)

if k > 0:
data_length = max(min(data_length, M - k), 0)
data_length = builtins.max(builtins.min(data_length, M - k), 0)
n_coords = np.arange(data_length, dtype=np.intp)
m_coords = n_coords + k
elif k < 0:
data_length = max(min(data_length, N + k), 0)
data_length = builtins.max(builtins.min(data_length, N + k), 0)
m_coords = np.arange(data_length, dtype=np.intp)
n_coords = m_coords - k
else:
Expand Down Expand Up @@ -1905,6 +1912,20 @@ def ones_like(a, dtype=None, shape=None, format=None, **kwargs):
return full_like(a, 1, dtype=dtype, shape=shape, format=format, **kwargs)


def empty(shape, dtype=float, format="coo", **kwargs):
return full(shape, 0, np.dtype(dtype)).asformat(format, **kwargs)


empty.__doc__ = zeros.__doc__


def empty_like(a, dtype=None, shape=None, format=None, **kwargs):
return full_like(a, 0, dtype=dtype, shape=shape, format=format, **kwargs)


empty_like.__doc__ = zeros_like.__doc__


def outer(a, b, out=None):
"""
Return outer product of two sparse arrays.
Expand Down Expand Up @@ -2088,3 +2109,178 @@ def format_to_string(format):
return format

raise ValueError(f"invalid format: {format}")


def asarray(
obj, /, *, dtype=None, format="coo", backend="pydata", device=None, copy=False
):
"""
Convert the input to a sparse array.
Parameters
----------
obj : array_like
Object to be converted to an array.
dtype : dtype, optional
Output array data type.
format : str, optional
Output array sparse format.
backend : str, optional
Backend for the output array.
device : str, optional
Device on which to place the created array.
copy : bool, optional
Boolean indicating whether or not to copy the input.
Returns
-------
out : Union[SparseArray, numpy.ndarray]
Sparse or 0-D array containing the data from `obj`.
Examples
--------
>>> x = np.eye(8, dtype='i8')
>>> sparse.asarray(x, format="COO")
<COO: shape=(8, 8), dtype=int64, nnz=8, fill_value=0>
"""
if format not in ["coo", "dok", "gcxs"]:
raise ValueError(f"{format} format not supported.")

if backend not in ["pydata", "taco"]:
raise ValueError(f"{backend} backend not supported.")

from ._coo import COO
from ._dok import DOK
from ._compressed import GCXS

format_dict = {"coo": COO, "dok": DOK, "gcxs": GCXS}

if backend == "pydata":
if isinstance(obj, (COO, DOK, GCXS)):
# TODO: consider `format` argument
warnings.warn("`format` argument was ignored")
return obj

elif isinstance(obj, spmatrix):
return format_dict[format].from_scipy_sparse(
obj.astype(dtype=dtype, copy=copy)
)

# check for scalars and 0-D arrays
elif np.isscalar(obj) or (isinstance(obj, np.ndarray) and obj.shape == ()):
return np.asarray(obj, dtype=dtype)

elif isinstance(obj, np.ndarray):
return format_dict[format].from_numpy(obj).astype(dtype=dtype, copy=copy)

else:
raise ValueError(f"{type(obj)} not supported.")

elif backend == "taco":
raise ValueError("Taco not yet supported.")


def _support_numpy(func):
"""
In case a NumPy array is passed to `sparse` namespace function
we want to flag it and dispatch to NumPy.
"""

def wrapper_func(*args, **kwargs):
x = args[0]
if isinstance(x, (np.ndarray, np.number)):
warnings.warn(
f"Sparse {func.__name__} received dense NumPy array instead "
"of sparse array. Dispatching to NumPy function."
)
return getattr(np, func.__name__)(*args, **kwargs)
else:
return func(*args, **kwargs)

return wrapper_func


def all(x, /, *, axis=None, keepdims=False):
return x.all(axis=axis, keepdims=keepdims)


def any(x, /, *, axis=None, keepdims=False):
return x.any(axis=axis, keepdims=keepdims)


def permute_dims(x, /, axes=None):
return x.transpose(axes=axes)


def max(x, /, *, axis=None, keepdims=False):
return x.max(axis=axis, keepdims=keepdims)


def mean(x, /, *, axis=None, keepdims=False, dtype=None):
return x.mean(axis=axis, keepdims=keepdims, dtype=dtype)


def min(x, /, *, axis=None, keepdims=False):
return x.min(axis=axis, keepdims=keepdims)


def prod(x, /, *, axis=None, dtype=None, keepdims=False):
return x.prod(axis=axis, keepdims=keepdims, dtype=dtype)


def std(x, /, *, axis=None, correction=0.0, keepdims=False):
return x.std(axis=axis, ddof=correction, keepdims=keepdims)


def sum(x, /, *, axis=None, dtype=None, keepdims=False):
return x.sum(axis=axis, keepdims=keepdims, dtype=dtype)


def var(x, /, *, axis=None, correction=0.0, keepdims=False):
return x.var(axis=axis, ddof=correction, keepdims=keepdims)


def abs(x, /):
return x.__abs__()


def reshape(x, /, shape, *, copy=None):
return x.reshape(shape=shape)


def astype(x, dtype, /, *, copy=True):
return x.astype(dtype, copy=copy)


@_support_numpy
def broadcast_to(x, /, shape):
return x.broadcast_to(shape)


def broadcast_arrays(*arrays):
shape = np.broadcast_shapes(*[a.shape for a in arrays])
return [a.broadcast_to(shape) for a in arrays]


def equal(x1, x2, /):
return x1 == x2


@_support_numpy
def round(x, /, decimals=0, out=None):
return x.round(decimals=decimals, out=out)


@_support_numpy
def isinf(x, /):
return x.isinf()


@_support_numpy
def isnan(x, /):
return x.isnan()


def isfinite(x, /):
return ~isinf(x)


def nonzero(x, /):
return x.nonzero()
Loading

0 comments on commit 01d8934

Please sign in to comment.