Source code for teneva_ht_jax.act_one
"""teneva_ht_jax.act_one: single HT-tensor operations.
This module contains the basic operations with one HT-tensor (Y), including
"get", etc.
"""
import jax
import jax.numpy as jnp
[docs]def get(Y, k):
"""Compute the element of the HT-tensor.
Args:
Y (list): d-dimensional HT-tensor.
k (np.ndarray): the multi-index for the tensor of the length d.
Returns:
float: the element of the HT-tensor.
"""
def body_leaf(q, data):
i1, i2, G1, G2, G = data
q = jnp.einsum('r,q,rsq->s', G1[i1], G2[i2], G)
return None, q
def body(q, data):
g1, g2, G = data
q = jnp.einsum('r,q,rsq->s', g1, g2, G)
return None, q
# Compute for the first level (leafs):
_, q = jax.lax.scan(body_leaf, None,
(k[0::2], k[1::2], Y[0][0::2], Y[0][1::2], Y[1]))
# Compute for the inner levels:
for k in range(1, len(Y)-2):
_, q = jax.lax.scan(body, None, (q[0::2], q[1::2], Y[k+1]))
# Compute for the last level (root):
q = jnp.einsum('r,q,rq->', q[0], q[1], Y[-1])
return q