einsum

Perform the equivalent of numpy.einsum.

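Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| subscripts | str | Specifies the subscripts for summation as a comma-separated list of subscript labels. An implicit (classical Einstein summation) calculation is performed unless the explicit indicator '->' is included together with the subscript labels of the exact output form. | required |
| operands | sequence of SparseArray | The arrays for the operation. | () |
| dtype | data-type | If provided, forces the calculation to use the specified data type. Optional. | None |
| **kwargs | dict | Any additional arguments to pass to the function. | {} |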

Returns

| Name | Type | Description |
| --- | --- | --- |
| output | SparseArray | The calculation based on the Einstein summation convention. |

Source code in sparse/numba_backend/_common.py
def einsum(*operands, **kwargs):
    """
    Perform the equivalent of [`numpy.einsum`][].

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : sequence of SparseArray
        These are the arrays for the operation.
    dtype : data-type, optional
        If provided, forces the calculation to use the data type specified.
        Default is `None`.
    **kwargs : dict, optional
        Any additional arguments to pass to the function.

    Returns
    -------
    output : SparseArray
        The calculation based on the Einstein summation convention.
    """

    lhs, rhs, operands = _parse_einsum_input(operands)  # Parse input

    check_zero_fill_value(*operands)

    if "dtype" in kwargs and kwargs["dtype"] is not None:
        operands = [o.astype(kwargs["dtype"]) for o in operands]

    if len(operands) == 1:
        return _einsum_single(lhs, rhs, operands[0])

    # if multiple arrays: align, broadcast multiply and then use single einsum
    # for example:
    #     "aab,cbd->dac"
    # we first perform single term reductions and align:
    #     aab -> ab..
    #     cbd -> .bcd
    # (where dots represent broadcastable size 1 dimensions), then multiply all
    # to form the 'minimal outer product' and do a final single term einsum:
    #     abcd -> dac

    # get ordered union of indices from all terms, indices that only appear
    # on a single term will be removed in the 'preparation' step below
    terms = lhs.split(",")
    total = {}
    sizes = {}
    for t, term in enumerate(terms):
        shape = operands[t].shape
        for ix, d in zip(term, shape, strict=False):
            if d != sizes.setdefault(ix, d):
                raise ValueError(f"Inconsistent shape for index '{ix}'.")
            total.setdefault(ix, set()).add(t)
    for ix in rhs:
        total[ix].add(-1)
    aligned_term = "".join(ix for ix, apps in total.items() if len(apps) > 1)

    # NB: if every index appears exactly twice,
    # we could identify and dispatch to tensordot here?

    parrays = []
    for term, array in zip(terms, operands, strict=True):
        # calc the target indices for this term
        pterm = "".join(ix for ix in aligned_term if ix in term)
        if pterm != term:
            # perform necessary transpose and reductions
            array = _einsum_single(term, pterm, array)
        # calc broadcastable shape
        shape = tuple(array.shape[pterm.index(ix)] if ix in pterm else 1 for ix in aligned_term)
        parrays.append(array.reshape(shape) if array.shape != shape else array)

    aligned_array = reduce(mul, parrays)

    return _einsum_single(aligned_term, rhs, aligned_array)
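
The multi-operand strategy described in the comments above (reduce each term, align it to a shared index order, broadcast-multiply, then run one final single-term einsum) can be sketched on dense NumPy arrays. This is an illustration under stated assumptions, not the library's implementation: `einsum_via_alignment` is a hypothetical name, `numpy.einsum` stands in for the internal `_einsum_single` helper, and only explicit subscripts containing '->' are handled.

```python
from functools import reduce
from operator import mul

import numpy as np


def einsum_via_alignment(subscripts, *arrays):
    """Dense sketch of the align / broadcast-multiply / reduce strategy.

    Hypothetical helper for illustration; requires explicit '->' subscripts.
    """
    lhs, rhs = subscripts.split("->")
    terms = lhs.split(",")

    # Record, for every index, its size and the set of terms it appears in.
    total = {}
    sizes = {}
    for t, (term, array) in enumerate(zip(terms, arrays)):
        for ix, d in zip(term, array.shape):
            if d != sizes.setdefault(ix, d):
                raise ValueError(f"Inconsistent shape for index '{ix}'.")
            total.setdefault(ix, set()).add(t)
    for ix in rhs:
        total[ix].add(-1)  # -1 marks "appears in the output"

    # Indices seen in only one place can be summed away per term up front.
    aligned_term = "".join(ix for ix, apps in total.items() if len(apps) > 1)

    parrays = []
    for term, array in zip(terms, arrays):
        # Target labels for this term, in the shared aligned order.
        pterm = "".join(ix for ix in aligned_term if ix in term)
        if pterm != term:
            # Single-term transpose, diagonal extraction and reduction.
            array = np.einsum(f"{term}->{pterm}", array)
        # Insert size-1 axes so every term broadcasts against the others.
        shape = tuple(sizes[ix] if ix in pterm else 1 for ix in aligned_term)
        parrays.append(array.reshape(shape))

    # 'Minimal outer product', followed by one final single-term einsum.
    aligned_array = reduce(mul, parrays)
    return np.einsum(f"{aligned_term}->{rhs}", aligned_array)


rng = np.random.default_rng(0)
A = rng.random((2, 2, 3))  # labels a, a, b
B = rng.random((4, 3, 5))  # labels c, b, d
result = einsum_via_alignment("aab,cbd->dac", A, B)
print(np.allclose(result, np.einsum("aab,cbd->dac", A, B)))
```

For "aab,cbd->dac" the aligned order is "abcd": the first term is reduced to "ab" and reshaped to (2, 3, 1, 1), the second is transposed to "bcd" and reshaped to (1, 3, 4, 5), and the final step sums over b. The intermediate product has one axis per shared or output index, which is cheap for sparse inputs but can be large for dense ones, so this dense version is for understanding only.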