Module gd: gradient descent for HT-tensors

teneva_ht_jax.gd: build HT-tensor with gradient descent.

This module contains the function “gd_appr”, which builds an HT-tensor approximation by gradient descent on the provided training dataset.




teneva_ht_jax.gd.gd_appr(Y, I_trn, y_trn, I_vld=None, y_vld=None, epochs=100, batch=100, lr=0.0001, seed=42, log=True, log_trn=True)[source]

Build HT-tensor with gradient descent.

Parameters:
  • Y (list) – HT-tensor used as the initial approximation for the algorithm.

  • I_trn (np.ndarray) – multi-indices of the tensor, given as an array of shape [samples, d], where d is the number of tensor dimensions and samples is the size of the training dataset.

  • y_trn (np.ndarray) – values of the tensor at the multi-indices I_trn, given as an array of shape [samples].

  • I_vld (np.ndarray) – optional multi-indices for the validation dataset, given as an array of shape [samples_vld, d], where samples_vld is the size of the validation dataset.

  • y_vld (np.ndarray) – optional values of the tensor at the multi-indices I_vld of the validation dataset, given as an array of shape [samples_vld].

  • epochs (int) – number of training epochs.

  • batch (int) – batch size used for training.

  • lr (float) – learning rate for training.

  • seed (int) – random seed (the signature above lists seed=42; the seed is presumably used to create the jax random key).

  • log (bool) – if flag is set, then the information about the progress of the algorithm will be printed after each epoch.

  • log_trn (bool) – if flag is set (and log is also set), then the accuracy on the train dataset will be presented after each epoch (note that for large training datasets this can take a significant amount of time, and it may be better to turn off this flag).

Returns:

HT-tensor, which is the constructed approximation.

Return type:

list

Examples:

First, we set the shape and ranks of the tensor:

d = 8          # Dimension of the tensor
n = 10         # Mode size for the tensor
r = [9, 7, 5]  # Ranks for tree layers

We set the target (discretized) function, for which we will try to build the HT-approximation:

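# Imports used throughout this example (the alias tnv is inferred from the calls below):
import jax
import jax.numpy as jnp
import numpy as np
import teneva_ht_jax as tnv


def func_build(d, n):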
    """Ackley function. See https://www.sfu.ca/~ssurjano/ackley.html."""

    a = -32.768         # Grid lower bound
    b = +32.768         # Grid upper bound

    par_a = 20.         # Standard parameter values for Ackley function
    par_b = 0.2
    par_c = 2.*jnp.pi

    def func(I):
        """Target function: y=f(I); [samples,d] -> [samples]."""
        X = I / (n - 1) * (b - a) + a

        y1 = jnp.sqrt(jnp.sum(X**2, axis=1) / d)
        y1 = - par_a * jnp.exp(-par_b * y1)

        y2 = jnp.sum(jnp.cos(par_c * X), axis=1)
        y2 = - jnp.exp(y2 / d)

        y3 = par_a + jnp.exp(1.)

        return y1 + y2 + y3

    return func

func = func_build(d, n)
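
As a quick sanity check of the index-to-grid mapping, the same function can be restated in plain NumPy. This is only a sketch: the name ackley_on_grid is hypothetical and is not part of teneva_ht_jax. It relies on two facts from the code above: index 0 maps to the lower bound a and index n - 1 to the upper bound b, and the Ackley function equals zero at the origin.

```python
import numpy as np

def ackley_on_grid(I, n, a=-32.768, b=+32.768):
    """NumPy restatement of func: [samples, d] indices -> [samples] values.

    Hypothetical helper for illustration; not a teneva_ht_jax function.
    """
    I = np.asarray(I, dtype=float)
    d = I.shape[1]
    X = I / (n - 1) * (b - a) + a   # Index 0 -> a, index n-1 -> b
    y1 = -20. * np.exp(-0.2 * np.sqrt(np.sum(X**2, axis=1) / d))
    y2 = -np.exp(np.sum(np.cos(2. * np.pi * X), axis=1) / d)
    return y1 + y2 + 20. + np.exp(1.)

# With an odd mode size the middle index lands exactly on x = 0,
# where the Ackley function has its global minimum f(0) = 0:
y_min = ackley_on_grid([[5] * 8], n=11)[0]

# The multi-index [0, 1, ..., 7] on the n = 10 grid of this example
# (the example below reports a value of about 21.17 for this element):
y_k = ackley_on_grid([list(range(8))], n=10)[0]

print(y_min, y_k)
```

Note that with n = 10 no grid point coincides with the origin, so the discretized tensor in this example does not contain the exact minimum.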

Then we generate train and validation data:

m_trn = 1.E+5  # Number of train items
m_vld = 1.E+3  # Number of validation items

rng = jax.random.PRNGKey(42)  # Initial random key (the seed value is an arbitrary choice)

rng, key = jax.random.split(rng)
I_trn = jax.random.choice(key, np.arange(n), (int(m_trn), d), replace=True)
y_trn = func(I_trn)

rng, key = jax.random.split(rng)
I_vld = jax.random.choice(key, np.arange(n), (int(m_vld), d), replace=True)
y_vld = func(I_vld)
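
To put the dataset size in perspective, the short calculation below compares the number of training items with the total number of tensor elements. It uses only the values of d, n and m_trn from this example; since the multi-indices are drawn with replacement, the number of distinct training elements can be slightly smaller.

```python
d, n = 8, 10            # Dimension and mode size from this example
m_trn = int(1.E+5)      # Number of train items

size = n**d             # Total number of tensor elements
frac = m_trn / size     # Upper bound on the share of elements seen in training

print(f'Tensor size : {size:.1e}')   # Tensor size : 1.0e+08
print(f'Train share : {frac:.1e}')   # Train share : 1.0e-03
```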

Let us build a random HT-tensor (the initial approximation):

rng, key = jax.random.split(rng)
Y = tnv.rand(d, n, r, key)

Next we set the parameters of the gradient descent method:

epochs = 20     # Number of train epochs
batch  = 100    # Size of the batch for train
lr     = 1.E-2  # Learning rate for training
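
To illustrate how these three parameters interact, here is a minimal mini-batch gradient descent loop for a toy least-squares problem in NumPy. It is a sketch under stated assumptions and not the implementation of gd_appr: the function minibatch_gd, the linear model and the per-epoch reshuffling are all hypothetical choices made for illustration. If each epoch passes once over the training data, as assumed here, the settings above (1.E+5 train items, batch = 100, epochs = 20) would correspond to 1000 updates per epoch and 20000 updates in total.

```python
import numpy as np

def minibatch_gd(w, X, y, epochs=20, batch=100, lr=1.E-1, seed=42):
    """Toy mini-batch gradient descent for the least-squares model X @ w.

    Hypothetical helper for illustration; not a teneva_ht_jax function.
    """
    rng = np.random.default_rng(seed)
    m = X.shape[0]
    for _ in range(epochs):
        perm = rng.permutation(m)                   # Reshuffle samples each epoch
        for start in range(0, m, batch):
            idx = perm[start:start + batch]         # Indices of the current batch
            res = X[idx] @ w - y[idx]               # Residuals on the batch
            grad = 2. * X[idx].T @ res / len(idx)   # Gradient of the mean squared error
            w = w - lr * grad                       # Gradient descent step
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))       # 1000 samples, 3 features
w_true = np.array([1., -2., 0.5])
y = X @ w_true                       # Noise-free targets

w = minibatch_gd(np.zeros(3), X, y)
steps = 20 * (1000 // 100)           # epochs * batches per epoch
print(steps, np.abs(w - w_true).max())
```

A larger batch gives fewer but less noisy updates per epoch, while the learning rate scales the size of every step; too large a value can make the iteration diverge.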

And now we can construct the HT-approximation using gradient descent:

Y = tnv.gd_appr(Y, I_trn, y_trn, I_vld, y_vld, epochs, batch, lr, log=True)

# >>> ----------------------------------------
# >>> Output:

# #   1 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :      2.61
# #   2 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :      3.78
# #   3 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :      4.98
# #   4 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :      6.18
# #   5 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :      7.39
# #   6 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :      8.59
# #   7 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :      9.87
# #   8 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :     11.06
# #   9 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :     12.20
# #  10 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :     13.34
# #  11 | e_trn : 1.0e+00 | e_vld : 1.0e+00 | t :     14.47
# #  12 | e_trn : 2.2e-02 | e_vld : 2.2e-02 | t :     15.60
# #  13 | e_trn : 3.5e-02 | e_vld : 3.5e-02 | t :     16.73
# #  14 | e_trn : 2.5e-02 | e_vld : 2.5e-02 | t :     17.86
# #  15 | e_trn : 2.1e-02 | e_vld : 2.1e-02 | t :     19.02
# #  16 | e_trn : 1.8e-02 | e_vld : 1.8e-02 | t :     20.17
# #  17 | e_trn : 2.0e-02 | e_vld : 2.0e-02 | t :     21.29
# #  18 | e_trn : 1.2e-02 | e_vld : 1.2e-02 | t :     22.41
# #  19 | e_trn : 1.4e-02 | e_vld : 1.4e-02 | t :     23.57
# #  20 | e_trn : 1.3e-02 | e_vld : 1.3e-02 | t :     24.73
#
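
The documentation above does not define the e_trn and e_vld columns. A common choice for such logs is the relative L2 error between predicted and reference values, and the sketch below assumes that definition; the helper rel_error is hypothetical and not part of teneva_ht_jax. Under this assumption, a value of 1.0e+00 arises, for example, when the predictions are negligible compared with the targets.

```python
import numpy as np

def rel_error(y_pred, y_true):
    """Relative L2 error ||y_pred - y_true|| / ||y_true|| (assumed metric)."""
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    return np.linalg.norm(y_pred - y_true) / np.linalg.norm(y_true)

y_true = np.array([20., 21., 22.])

e_zero = rel_error(np.zeros(3), y_true)     # All-zero prediction
e_close = rel_error(1.01 * y_true, y_true)  # Prediction with a 1% overshoot

print(f'{e_zero:7.1e}')    # 1.0e+00
print(f'{e_close:7.1e}')   # 1.0e-02
```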

Let us select a tensor element and compute its value from the constructed approximation:

k = np.array([0, 1, 2, 3, 4, 5, 6, 7])
y = tnv.get(Y, k)
print(y)

# >>> ----------------------------------------
# >>> Output:

# 21.239890172224758
#

We can compute the same element from the target function:

y_real = func(k.reshape(1, -1))[0]
print(y_real)

# >>> ----------------------------------------
# >>> Output:

# 21.170489539765892
#

Let us compare the approximated and exact values:

e = np.abs(y - y_real)
print(f'Error : {e:7.1e}')

# >>> ----------------------------------------
# >>> Output:

# Error : 6.9e-02
#