# Source code for teneva_ht_jax.tensors

"""teneva_ht_jax.tensors: various useful HT-tensors.

This module contains a collection of functions for the explicit construction
of various useful HT-tensors (currently only a random HT-tensor).

"""
import jax
import jax.numpy as jnp


def rand(d, n, r, key, a=-1., b=1.):
    """Construct a random HT-tensor from the uniform distribution.

    The HT-tensor is stored level by level, from the leaves up to the root.
    For d = 2**(q-1) dimensions there are q levels:
    - the leaf level holds d cores;
    - each inner level holds half as many cores as the level below it;
    - the root is a single matrix.
    All items are drawn uniformly from the half-open interval [a, b).

    Args:
        d (int): number of tensor dimensions. It must be a power of two and
            not less than 2, because the level sizes below are computed for a
            perfect binary dimension tree.
        n (int): mode size of the tensor.
        r (int, list): HT-ranks of the tensor. It should be a number (if all
            ranks are equal) or list of the length q-1, where q is a number of
            levels (q = log2(d) + 1, e.g., d=8 -> q=4).
        key (jax.random.PRNGKey): jax random key.
        a (float): minimum value for random items of the HT-cores.
        b (float): maximum value for random items of the HT-cores.

    Returns:
        list: HT-tensor, i.e., the list of q jax arrays:
        - the leaf cores of shape (d, n, r[0]);
        - for each inner level k = 1, ..., q-2, the cores of shape
          (d // 2**k, r[k-1], r[k], r[k-1]);
        - the root core of shape (r[-1], r[-1]).

    Raises:
        ValueError: if d is not a power of two (or is less than 2), or if the
            list of ranks does not have the length q-1.

    """
    if d < 2 or d & (d - 1):
        # Level lengths are computed as powers of two below; any other d would
        # silently produce cores that do not fit together into a binary tree.
        raise ValueError('Number of dimensions should be a power of 2 (>= 2)')

    q = d.bit_length()  # Full number of levels (e.g., d=8 -> q=4)

    if isinstance(r, int):
        r = [r] * (q-1)
    if len(r) != (q-1):
        raise ValueError('Invalid length of ranks list')

    def _rand_level(key, sh):
        # Draw one level with a fresh subkey and return the key to continue.
        key, key_cur = jax.random.split(key)
        Yl = jax.random.uniform(key_cur, sh, minval=a, maxval=b)
        return Yl, key

    Y = []

    # Build the first level (leaves):
    Yl, key = _rand_level(key, sh=(d, n, r[0]))
    Y.append(Yl)  # 3D tensor (len, n, r_up)

    # Build the inner levels:
    for k in range(1, q-1):
        dl = 2**(q-k-1)  # Length of the current layer (equals d // 2**k)
        Yl, key = _rand_level(key, sh=(dl, r[k-1], r[k], r[k-1]))
        Y.append(Yl)  # 4D tensor (len, r_down, r_up, r_down)

    # Build the last level (root):
    Yl, key = _rand_level(key, sh=(r[-1], r[-1]))
    Y.append(Yl)  # 2D tensor (r_down, r_down)

    return Y