# Source code for teneva_ht_jax.gd
"""teneva_ht_jax.gd: build HT-tensor with gradient descent.
This module contains the function "gd_appr", which computes the HT-tensor by
the gradient descent method using the provided training dataset.
"""
import jax
import jax.numpy as jnp
import optax
import teneva_ht_jax as tnv
from time import perf_counter as tpc
def gd_appr(Y, I_trn, y_trn, I_vld=None, y_vld=None, epochs=100, batch=100,
            lr=1.E-4, seed=42, log=True, log_trn=True):
    """Build HT-tensor with gradient descent.

    The initial approximation is copied and then optimized with the Adam
    optimizer. The loss on each batch is the Euclidean norm of the difference
    between the real values and the values of the current HT-tensor. Before
    each epoch the train dataset is shuffled; the trailing samples that do not
    fill a complete batch are skipped in that epoch.

    Args:
        Y (list): HT-tensor, which is the initial approximation for algorithm.
        I_trn (np.ndarray): multi-indices for the tensor in the form of array
            of the shape [samples, d], where d is a number of tensor's
            dimensions and samples is a size of the train dataset.
        y_trn (np.ndarray): values of the tensor for multi-indices I_trn in
            the form of array of the shape [samples].
        I_vld (np.ndarray): optional multi-indices for items of validation
            dataset in the form of array of the shape [samples_vld, d], where
            samples_vld is a size of the validation dataset.
        y_vld (np.ndarray): optional values of the tensor for multi-indices
            I_vld of validation dataset in the form of array of the shape
            [samples_vld].
        epochs (int): number of train epochs.
        batch (int): size of the batch for train. If it is greater than the
            size of the train dataset, then the whole dataset is used as one
            batch.
        lr (float): learning rate for training.
        seed (int): seed for the jax random key, which is used to shuffle the
            train dataset before each epoch.
        log (bool): if flag is set, then the information about the progress of
            the algorithm will be printed after each epoch.
        log_trn (bool): if flag is set (and log is also set), then the
            accuracy on the train dataset will be presented after each epoch
            (note that for large training datasets this can take a significant
            amount of time, and it may be better to turn off this flag).

    Returns:
        list: HT-tensor, which is the constructed approximation.

    Raises:
        ValueError: if batch is less than 1.

    """
    if batch < 1:
        raise ValueError('Batch size should be a positive integer')

    t_start = tpc()
    rng = jax.random.PRNGKey(seed)

    # Vectorized getter: the same HT-tensor is evaluated for each multi-index.
    get = jax.jit(jax.vmap(tnv.get, (None, 0)))

    optim = optax.adam(lr)
    # Work on a copy, so the initial approximation of the caller is kept.
    Y = [Yl.copy() for Yl in Y]
    state = optim.init(Y)

    def loss(Y, I, y_real):
        # The norm is already a scalar, hence no extra reduction is needed.
        return jnp.linalg.norm(y_real - get(Y, I))

    # No separate jit here: it is traced inside the jitted "optimize" step.
    loss_grad = jax.grad(loss)

    @jax.jit
    def optimize(Y, state, I_cur, y_cur):
        grads = loss_grad(Y, I_cur, y_cur)
        updates, state = optim.update(grads, state)
        Y = optax.apply_updates(Y, updates)
        return Y, state

    samples = I_trn.shape[0]
    if 0 < samples < batch:
        # Otherwise "samples // batch" is zero and no train step is ever run.
        batch = samples
    steps = samples // batch

    for epoch in range(epochs):
        rng, key = jax.random.split(rng)
        perm = jax.random.permutation(key, samples)
        I_trn_cur = I_trn[perm]
        y_trn_cur = y_trn[perm]

        for j in range(steps):
            lo, hi = j * batch, (j + 1) * batch
            Y, state = optimize(Y, state, I_trn_cur[lo:hi], y_trn_cur[lo:hi])

        if log:
            text = f'# {epoch+1:-3d} | '

            if log_trn:
                y_our = get(Y, I_trn)
                e_trn = jnp.linalg.norm(y_trn - y_our)
                e_trn = e_trn / jnp.linalg.norm(y_trn)
                text += f'e_trn : {e_trn:-7.1e} | '

            if I_vld is not None and y_vld is not None:
                y_our = get(Y, I_vld)
                e_vld = jnp.linalg.norm(y_vld - y_our)
                e_vld = e_vld / jnp.linalg.norm(y_vld)
                text += f'e_vld : {e_vld:-7.1e} | '

            text += f't : {tpc()-t_start:-9.2f}'
            print(text)

    return Y