def einsum(*operands, **kwargs):
"""
Perform the equivalent of [`numpy.einsum`][].
Parameters
----------
subscripts : str
Specifies the subscripts for summation as comma separated list of
subscript labels. An implicit (classical Einstein summation)
calculation is performed unless the explicit indicator '->' is
included as well as subscript labels of the precise output form.
operands : sequence of SparseArray
These are the arrays for the operation.
dtype : data-type, optional
If provided, forces the calculation to use the data type specified.
Default is `None`.
**kwargs : dict, optional
Any additional arguments to pass to the function.
Returns
-------
output : SparseArray
The calculation based on the Einstein summation convention.
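
    Examples
    --------
    A minimal usage sketch (this assumes `sparse.COO` inputs with a zero fill
    value and the top-level `sparse.einsum` wrapper; semantics mirror
    `numpy.einsum`):

    >>> import numpy as np
    >>> import sparse
    >>> x = sparse.COO.from_numpy(np.arange(6).reshape(2, 3))
    >>> y = sparse.COO.from_numpy(np.arange(12).reshape(3, 4))
    >>> sparse.einsum("ij,jk->ik", x, y).todense()  # matrix product
    array([[20, 23, 26, 29],
           [56, 68, 80, 92]])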
"""
    lhs, rhs, operands = _parse_einsum_input(operands)  # Parse input
    check_zero_fill_value(*operands)

    if "dtype" in kwargs and kwargs["dtype"] is not None:
        operands = [o.astype(kwargs["dtype"]) for o in operands]

    if len(operands) == 1:
        return _einsum_single(lhs, rhs, operands[0])

    # if multiple arrays: align, broadcast multiply and then use single einsum
    # for example:
    #     "aab,cbd->dac"
    # we first perform single term reductions and align:
    #     aab -> ab..
    #     cbd -> .bcd
    # (where dots represent broadcastable size 1 dimensions), then multiply all
    # to form the 'minimal outer product' and do a final single term einsum:
    #     abcd -> dac

    # get ordered union of indices from all terms, indices that only appear
    # on a single term will be removed in the 'preparation' step below
    terms = lhs.split(",")
    total = {}
    sizes = {}
    for t, term in enumerate(terms):
        shape = operands[t].shape
        for ix, d in zip(term, shape, strict=False):
            if d != sizes.setdefault(ix, d):
                raise ValueError(f"Inconsistent shape for index '{ix}'.")
            total.setdefault(ix, set()).add(t)
    for ix in rhs:
        total[ix].add(-1)
    aligned_term = "".join(ix for ix, apps in total.items() if len(apps) > 1)
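    # note: output indices from `rhs` were tagged with the sentinel -1 above, so
    # an index that appears in only one input term but also in the output is
    # still kept in `aligned_term`; for the "aab,cbd->dac" example this gives
    # aligned_term == "abcd"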

    # NB: if every index appears exactly twice,
    # we could identify and dispatch to tensordot here?

    parrays = []
    for term, array in zip(terms, operands, strict=True):
        # calc the target indices for this term
        pterm = "".join(ix for ix in aligned_term if ix in term)
        if pterm != term:
            # perform necessary transpose and reductions
            array = _einsum_single(term, pterm, array)
        # calc broadcastable shape
        shape = tuple(array.shape[pterm.index(ix)] if ix in pterm else 1 for ix in aligned_term)
        parrays.append(array.reshape(shape) if array.shape != shape else array)
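    # e.g. for "aab,cbd->dac": "aab" is reduced to pterm "ab" and reshaped to
    # (a, b, 1, 1), while "cbd" is transposed to pterm "bcd" and reshaped to
    # (1, b, c, d), so the broadcast product below carries the indices "abcd";
    # `reduce` and `mul` are assumed to be module-level imports not shown in
    # this excerpt (e.g. `from functools import reduce`, `from operator import mul`)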

    aligned_array = reduce(mul, parrays)
    return _einsum_single(aligned_term, rhs, aligned_array)