跳到内容

take

返回数组沿轴的元素。

参数

名称 类型 描述 默认值
x SparseArray

输入数组。

必需
indices ndarray

数组索引。该数组必须是一维的,并且具有整数数据类型。

必需
axis int

用于选择值的轴。如果axis为负数,函数必须通过从最后一个维度开始计数来确定选择值的轴。如果为None,则使用展平的输入数组。默认值:None

None

返回值

名称 类型 描述
out COO

一个包含请求索引的COO数组。

引发

类型 描述
ValueError

如果输入数组不是COO格式且无法转换为COO格式。

源代码位于 sparse/numba_backend/_coo/common.py
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
def take(x, indices, /, *, axis=None):
    """
    Returns elements of an array along an axis.

    Parameters
    ----------
    x : SparseArray
        Input array.
    indices : ndarray
        Array indices. The array must be one-dimensional and have an integer data type.
    axis : int
        Axis over which to select values. If ``axis`` is negative, the function must
        determine the axis along which to select values by counting from the last dimension.
        For ``None``, the flattened input array is used. Default: ``None``.

    Returns
    -------
    out : COO
        A COO array with requested indices.

    Raises
    ------
    ValueError
        If the input array isn't and can't be converted to COO format.
    """

    x = _validate_coo_input(x)

    if axis is None:
        x = x.flatten()
        return x[indices]

    axis = normalize_axis(axis, x.ndim)
    full_index = (slice(None),) * axis + (indices, ...)
    return x[full_index]