Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Lion optimizer #610

Merged
merged 6 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions keras_core/backend/torch/optimizers/torch_lion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch

from keras_core import ops
from keras_core import optimizers
from keras_core.backend.torch.optimizers import torch_parallel_optimizer


class Lion(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Lion):
def _parallel_update_step(
self,
grads,
variables,
learning_rate,
):
keras_variables = variables
variables = [v.value for v in variables]

dtype = variables[0].dtype
lr = ops.cast(learning_rate, dtype)

m_list = [
self._momentums[self._get_variable_index(variable)].value
for variable in keras_variables
]

c_t = torch._foreach_mul(m_list, self.beta_1)
torch._foreach_add_(c_t, grads, alpha=1 - self.beta_1)
c_t = [c.sign() for c in c_t]

torch._foreach_add_(
variables,
torch._foreach_mul(c_t, lr),
alpha=-1,
)

torch._foreach_mul_(m_list, self.beta_2)
torch._foreach_add_(m_list, grads, alpha=1 - self.beta_2)
2 changes: 2 additions & 0 deletions keras_core/backend/torch/optimizers/torch_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __new__(cls, *args, **kwargs):
from keras_core.backend.torch.optimizers import torch_adam
from keras_core.backend.torch.optimizers import torch_adamax
from keras_core.backend.torch.optimizers import torch_adamw
from keras_core.backend.torch.optimizers import torch_lion
from keras_core.backend.torch.optimizers import torch_nadam
from keras_core.backend.torch.optimizers import torch_rmsprop
from keras_core.backend.torch.optimizers import torch_sgd
Expand All @@ -22,6 +23,7 @@ def __new__(cls, *args, **kwargs):
optimizers.Adam: torch_adam.Adam,
optimizers.Adamax: torch_adamax.Adamax,
optimizers.AdamW: torch_adamw.AdamW,
optimizers.Lion: torch_lion.Lion,
optimizers.Nadam: torch_nadam.Nadam,
optimizers.RMSprop: torch_rmsprop.RMSprop,
optimizers.SGD: torch_sgd.SGD,
Expand Down
2 changes: 2 additions & 0 deletions keras_core/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from keras_core.optimizers.adamax import Adamax
from keras_core.optimizers.adamw import AdamW
from keras_core.optimizers.ftrl import Ftrl
from keras_core.optimizers.lion import Lion
from keras_core.optimizers.nadam import Nadam
from keras_core.optimizers.optimizer import Optimizer
from keras_core.optimizers.rmsprop import RMSprop
Expand All @@ -24,6 +25,7 @@
Adafactor,
Nadam,
Ftrl,
Lion,
}
ALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS}

Expand Down
123 changes: 123 additions & 0 deletions keras_core/optimizers/lion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import optimizer


@keras_core_export(["keras_core.optimizers.Lion"])
class Lion(optimizer.Optimizer):
"""Optimizer that implements the Lion algorithm.

The Lion optimizer is a stochastic-gradient-descent method that uses the
sign operator to control the magnitude of the update, unlike other adaptive
optimizers such as Adam that rely on second-order moments. This make
Lion more memory-efficient as it only keeps track of the momentum. According
to the authors (see reference), its performance gain over Adam grows with
the batch size. Because the update of Lion is produced through the sign
operation, resulting in a larger norm, a suitable learning rate for Lion is
typically 3-10x smaller than that for AdamW. The weight decay for Lion
should be in turn 3-10x larger than that for AdamW to maintain a
similar strength (lr * wd).

Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to `0.001`.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
rate to combine the current gradient and the 1st moment estimate.
Defaults to `0.9`.
beta_2: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimate. Defaults to
`0.99`.
{{base_optimizer_keyword_args}}

References:

- [Chen et al., 2023](http://arxiv.org/abs/2302.06675)
- [Authors' implementation](
http://github.com/google/automl/tree/master/lion)

"""

def __init__(
self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.99,
weight_decay=None,
clipnorm=None,
clipvalue=None,
global_clipnorm=None,
use_ema=False,
ema_momentum=0.99,
ema_overwrite_frequency=None,
name="lion",
):
super().__init__(
learning_rate=learning_rate,
name=name,
weight_decay=weight_decay,
clipnorm=clipnorm,
clipvalue=clipvalue,
global_clipnorm=global_clipnorm,
use_ema=use_ema,
ema_momentum=ema_momentum,
ema_overwrite_frequency=ema_overwrite_frequency,
)
self.beta_1 = beta_1
self.beta_2 = beta_2
if beta_1 <= 0 or beta_1 > 1:
raise ValueError(
"Argument `beta_1` must be in the [0, 1] range. Otherwise, the "
f"optimizer degenerates to SignSGD. Received: beta_1={beta_1}."
)

def build(self, var_list):
"""Initialize optimizer variables.

Lion optimizer has one variable `momentums`.

Args:
var_list: list of model variables to build Lion variables on.
"""
if self.built:
return
super().build(var_list)
self._momentums = []
for var in var_list:
self._momentums.append(
self.add_variable_from_reference(
reference_variable=var, name="m"
)
)

def update_step(self, gradient, variable, learning_rate):
"""Update step given gradient and the associated model variable."""
lr = ops.cast(learning_rate, variable.dtype)
gradient = ops.cast(gradient, variable.dtype)
beta_1 = ops.cast(self.beta_1, variable.dtype)
beta_2 = ops.cast(self.beta_2, variable.dtype)
m = self._momentums[self._get_variable_index(variable)]

# TODO: currently only support dense gradients
variable.assign_sub(
lr * ops.sign(m * beta_1 + gradient * (1.0 - beta_1))
)
m.assign(m * beta_2 + gradient * (1.0 - beta_2))

def get_config(self):
config = super().get_config()
config.update(
{
"beta_1": self.beta_1,
"beta_2": self.beta_2,
}
)
return config


Lion.__doc__ = Lion.__doc__.replace(
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
85 changes: 85 additions & 0 deletions keras_core/optimizers/lion_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
import pytest

import keras_core
from keras_core import backend
from keras_core import ops
from keras_core import testing
from keras_core.optimizers.lion import Lion


class LionTest(testing.TestCase):
def test_config(self):
optimizer = Lion(
learning_rate=0.5,
beta_1=0.5,
beta_2=0.67,
)
self.run_class_serialization_test(optimizer)

def test_single_step(self):
optimizer = Lion(learning_rate=0.5)
grads = ops.array([1.0, 6.0, 7.0, 2.0])
vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
optimizer.apply_gradients(zip([grads], [vars]))
self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)

def test_weight_decay(self):
grads, var1, var2, var3 = (
ops.zeros(()),
backend.Variable(2.0),
backend.Variable(2.0, name="exclude"),
backend.Variable(2.0),
)
optimizer_1 = Lion(learning_rate=1.0, weight_decay=0.004)
optimizer_1.apply_gradients(zip([grads], [var1]))

optimizer_2 = Lion(learning_rate=1.0, weight_decay=0.004)
optimizer_2.exclude_from_weight_decay(var_names=["exclude"])
optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))

optimizer_3 = Lion(learning_rate=1.0, weight_decay=0.004)
optimizer_3.exclude_from_weight_decay(var_list=[var3])
optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))

self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)
self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)
self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)

def test_correctness_with_golden(self):
optimizer = Lion()

x = backend.Variable(np.ones([10]))
grads = ops.arange(0.1, 1.1, 0.1)
first_grads = ops.full((10,), 0.01)

golden = np.tile(
[[0.999], [0.998], [0.997], [0.996], [0.995]],
(1, 10),
)

optimizer.apply_gradients(zip([first_grads], [x]))
for i in range(5):
self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)
optimizer.apply_gradients(zip([grads], [x]))

def test_clip_norm(self):
optimizer = Lion(clipnorm=1)
grad = [np.array([100.0, 100.0])]
clipped_grad = optimizer._clip_gradients(grad)
self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])

def test_clip_value(self):
optimizer = Lion(clipvalue=1)
grad = [np.array([100.0, 100.0])]
clipped_grad = optimizer._clip_gradients(grad)
self.assertAllClose(clipped_grad[0], [1.0, 1.0])

@pytest.mark.requires_trainable_backend
def test_ema(self):
# TODO: test correctness
model = keras_core.Sequential([keras_core.layers.Dense(10)])
model.compile(optimizer=Lion(use_ema=True), loss="mse")
x = keras_core.ops.zeros((1, 5))
y = keras_core.ops.zeros((1, 10))
model.fit(x, y)