跳到内容

dot

在两个数组上执行与 numpy.dot 等效的操作。

参数

名称 类型 描述 默认值
a 联合[SparseArray, ndarray, spmatrix]

要执行 dot 操作的数组。

必需
b 联合[SparseArray, ndarray, spmatrix]

要执行 dot 操作的数组。

必需

返回值

类型 描述
联合[SparseArray, ndarray]

操作的结果。

引发

类型 描述
ValueError

如果所有参数都没有零填充值。

另请参阅
源代码位于 sparse/numba_backend/_common.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
def dot(a, b):
    """
    Perform the equivalent of [`numpy.dot`][] on two arrays.

    Parameters
    ----------
    a, b : Union[SparseArray, np.ndarray, scipy.sparse.spmatrix]
        The arrays to perform the `dot` operation on.

    Returns
    -------
    Union[SparseArray, numpy.ndarray]
        The result of the operation.

    Raises
    ------
    ValueError
        If all arguments don't have zero fill-values.

    See Also
    --------
    - [`numpy.dot`][] : NumPy equivalent function.
    - [`sparse.COO.dot`][] : Equivalent function for COO objects.
    """
    check_zero_fill_value(a, b)
    if not hasattr(a, "ndim") or not hasattr(b, "ndim"):
        raise TypeError(f"Cannot perform dot product on types {type(a)}, {type(b)}")

    if a.ndim == 1 and b.ndim == 1:
        if isinstance(a, SparseArray):
            a = as_coo(a)
        if isinstance(b, SparseArray):
            b = as_coo(b)
        return (a * b).sum()

    a_axis = -1
    b_axis = -2

    if b.ndim == 1:
        b_axis = -1
    return tensordot(a, b, axes=(a_axis, b_axis))