跳到内容

循环移位

沿着指定轴移动数组元素。移出末尾位置的元素将被循环并重新引入到起始位置。

参数

名称 类型 描述 默认值
a COO

输入数组

必需
shift 整型或整数元组

元素移位的索引位置数量。如果提供元组,则轴(axis)也必须是相同大小的元组,并且每个给定轴按相应数量进行移位。如果轴(axis)是一个整数元组,而移位量(shift)是一个整数,则会使用广播,从而将相同的移位应用于所有轴。

必需
axis 整型或整数元组

指定单个或多个轴的轴或元组。默认情况下,数组在移位前会被展平,之后恢复原始形状。

None

返回值

名称 类型 描述
res ndarray

输出数组,形状与 a 相同。

源代码位于 sparse/numba_backend/_coo/common.py
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
def roll(a, shift, axis=None):
    """
    Shifts elements of an array along specified axis. Elements that roll beyond
    the last position are circulated and re-introduced at the first.

    Parameters
    ----------
    a : COO
        Input array
    shift : int or tuple of ints
        Number of index positions that elements are shifted. If a tuple is
        provided, then axis must be a tuple of the same size, and each of the
        given axes is shifted by the corresponding number. If an int while axis
        is a tuple of ints, then broadcasting is used so the same shift is
        applied to all axes.
    axis : int or tuple of ints, optional
        Axis or tuple specifying multiple axes. By default, the
        array is flattened before shifting, after which the original shape is
        restored.

    Returns
    -------
    res : ndarray
        Output array, with the same shape as a.
    """
    from .core import COO, as_coo

    a = as_coo(a)

    # roll flattened array
    if axis is None:
        return roll(a.reshape((-1,)), shift, 0).reshape(a.shape)

    # roll across specified axis
    # parse axis input, wrap in tuple
    axis = normalize_axis(axis, a.ndim)
    if not isinstance(axis, tuple):
        axis = (axis,)

    # make shift iterable
    if not isinstance(shift, Iterable):
        shift = (shift,)

    elif np.ndim(shift) > 1:
        raise ValueError("'shift' and 'axis' must be integers or 1D sequences.")

    # handle broadcasting
    if len(shift) == 1:
        shift = np.full(len(axis), shift)

    # check if dimensions are consistent
    if len(axis) != len(shift):
        raise ValueError("If 'shift' is a 1D sequence, 'axis' must have equal length.")

    if not can_store(a.coords.dtype, max(a.shape + shift)):
        raise ValueError(
            f"cannot roll with coords.dtype {a.coords.dtype} and shift {shift}. Try casting coords to a larger dtype."
        )

    # shift elements
    coords, data = np.copy(a.coords), np.copy(a.data)
    try:
        for sh, ax in zip(shift, axis, strict=True):
            coords[ax] += sh
            coords[ax] %= a.shape[ax]
    except TypeError as e:
        if is_unsigned_dtype(coords.dtype):
            raise ValueError(
                f"rolling with coords.dtype as {coords.dtype} is not safe. Try using a signed dtype."
            ) from e

    return COO(
        coords,
        data=data,
        shape=a.shape,
        has_duplicates=False,
        fill_value=a.fill_value,
    )