Hello! I tried to implement the implicit differentiation example shown here, but with my own functions.
The task is to find the mean of a set of vectors X via gradient descent.
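Spelled out (this is just my restatement of what weighted_distance below encodes, with m vectors X_i and weights w_i), the inner problem and the gradient-descent update are:

\bar{x}(w) = \arg\min_{x} \sum_{i=1}^{m} w_i \,\lVert x - X_i \rVert^2,
\qquad
Y \leftarrow Y - \text{lr} \cdot \nabla_Y \sum_{i=1}^{m} w_i \,\lVert Y - X_i \rVert^2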
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import jax
import jax.numpy as jnp
from jax import grad, random, jit
from jax import jacobian, hessian, jacfwd, jacrev
key = random.PRNGKey(0)
import jaxopt
from jaxopt import implicit_diff
from jaxopt import linear_solve
from jaxopt import OptaxSolver, GradientDescent
import optax
def euclidean_distance(a, b):
    """Squared Euclidean distance between two vectors."""
    return jnp.inner(a - b, a - b)

def weighted_distance(x, X, w):
    """Weighted sum of squared distances from x to every vector in X."""
    loss = 0
    for i, obj in enumerate(X):
        loss += w[i] * euclidean_distance(obj, x)
    return loss

def identical(Y, Y_grad):
    # identity map: returns Y unchanged, ignoring the gradient
    return Y
The algorithm for finding the mean:
# Mean calculation for manifolds with gradient descent
@implicit_diff.custom_root(jax.grad(weighted_distance))
def euclidean_weighted_mean(X_set, weights=None, lr=0.1, n_iter=50, plot_loss_flag=False):
    if weights is None:
        weights = jnp.ones(X_set.shape[0]) / X_set.shape[0]
    # init mean with a random element from the set
    Y = X_set[np.random.randint(0, X_set.shape[0])]
    if plot_loss_flag:
        plot_loss = []
        prev_loss = 0
        plateau_iter = 0
        plateau_reached = False
    for i in range(n_iter):
        # calculate loss
        loss = weighted_distance(Y, X_set, weights)
        if plot_loss_flag:
            if jnp.allclose(jnp.array(loss), jnp.array(prev_loss)):
                if not plateau_reached:
                    plateau_iter = i
                    plateau_reached = True
            else:
                prev_loss = loss
                plateau_reached = False
        Y_grad = grad(weighted_distance, argnums=0)(Y, X_set, weights)
        # calculate Riemannian gradient (equal to the Euclidean one in flat space)
        riem_grad_Y = Y_grad
        # update Y
        Y_step = Y - lr * riem_grad_Y
        # project the new Y back onto the manifold with a retraction
        Y = Y_step
        if plot_loss_flag:
            # collect loss for plotting
            plot_loss.append(loss)
    if plot_loss_flag:
        print(f"Total loss: {weighted_distance(Y, X_set, weights)} reached after {plateau_iter} iterations")
        fig, ax = plt.subplots()
        ax.plot(plot_loss)
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Loss")
        plt.show()
    return Y
You can launch it like this:
d = 2
m = 4
X = jax.random.uniform(key, (m,d))
euclidean_weighted_mean(X, weights = None, lr = 1e-3, n_iter = 100, plot_loss_flag = True)
As you can see, I am computing the weighted version of the mean, and that's where I use jaxopt. Let me define the global objective (just as an example): I want the weights to take the values that minimise the distance between the resulting mean and a desired target point. In my case, I want the weights to influence the algorithm in such a way that the resulting mean ends up as close to X[0] as possible:
def global_task_objective(w, X, target_point, lr, n_iter):
    # inner solve: weighted mean for the current weights
    x = euclidean_weighted_mean(X, w, lr=lr, n_iter=n_iter)
    # outer loss: squared distance between the obtained mean and the target
    loss = euclidean_distance(x, target_point)
    return loss, x
target_point = X[0]
w_init = jnp.array(np.random.randn(X.shape[0])) * jnp.square(2 / X.shape[0])
lr = 1e-3
n_iter = 100
global_task_objective(w_init, X, target_point, lr, n_iter)
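Written as a bilevel problem (my reading of the setup above, with X_0 = X[0]), this is:

\min_{w} \;\lVert \bar{x}(w) - X_0 \rVert^2
\quad\text{s.t.}\quad
\bar{x}(w) = \arg\min_{x} \sum_{i=1}^{m} w_i \,\lVert x - X_i \rVert^2,

and custom_root is supposed to supply the implicit gradient of \bar{x}(w) with respect to w for the outer optimiser.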
solver = OptaxSolver(opt=optax.amsgrad(1e-2), fun=global_task_objective, has_aux=True)
state = solver.init_state(w_init, X=X, target_point=target_point, lr=lr, n_iter=n_iter)
The problem emerges when I call:
w_init, state = solver.update(params=w_init,
                              state=state,
                              X=X, target_point=target_point, lr=lr, n_iter=n_iter)
Meanwhile, the official Ridge regression example works perfectly. Any suggestions?
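For completeness, a stripped-down check that should go through the same implicit-differentiation path (just a sketch based on the definitions above, not part of my solver loop) is to ask JAX for the outer gradient directly:

# sanity-check sketch: differentiate the outer objective w.r.t. the weights;
# this exercises custom_root's implicit gradient without involving OptaxSolver
outer_grad = jax.grad(global_task_objective, argnums=0, has_aux=True)
w_grad, x_mean = outer_grad(w_init, X, target_point, lr, n_iter)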