Source code for teneva_ht_jax.vis

"""teneva_ht_jax.vis: visualization methods for tensors.

This module contains the functions for visualization of HT-tensors.

"""
import jax.numpy as jnp


def show(Y):
    """Display mode size and ranks of the given HT-tensor.

    Prints a header line with the number of dimensions ``d`` (read as the
    leading dimension of the first item of ``Y``), followed by one line per
    item of ``Y`` showing its 1-based level index and its shape.

    Args:
        Y (list): HT-tensor, given as a non-empty list of levels; each level
            must expose a ``.shape`` attribute (e.g. a jax/numpy array).

    Raises:
        ValueError: if ``Y`` is not a list or is an empty list.

    Todo:
        Add more accurate and informative visualization.

    """
    # An empty list has no first level to read ``d`` from; reject it with the
    # same error as any other malformed input instead of leaking IndexError.
    if not isinstance(Y, list) or not Y:
        raise ValueError('Invalid HT-tensor')

    d = Y[0].shape[0]
    print(f'HT-tensor | d={d:-3d}')

    # Levels are reported 1-based.
    for q, Yl in enumerate(Y, start=1):
        print(f'Level: {q:-3d} | Shape: {Yl.shape}')