Port TensorFlow implementation of the RandomCrop layer #493

Merged
merged 34 commits into from
Aug 6, 2023

Changes from 31 commits

Commits (34 commits)
7740d17
update: port tf implementation of RandomCrop layer
soumik12345 Jul 15, 2023
0d60c8b
update: import + data_format + attempt to replace tf with ops
soumik12345 Jul 15, 2023
efb57ff
update: compute_output_shape returns tuple
soumik12345 Jul 15, 2023
77aa7a3
update: replace tf.convert_to_tensor with ops.core.convert_to_tensor
soumik12345 Jul 15, 2023
72edc5f
update: replace tf.random.uniform with keras_core.random.uniform
soumik12345 Jul 15, 2023
963dc14
update: cropping logic
soumik12345 Jul 16, 2023
5487542
Merge branch 'keras-team:main' into port/random-crop
soumik12345 Jul 17, 2023
43072cb
update: RandomCrop inherits from TFDataLayer + standardized data_format
soumik12345 Jul 17, 2023
7e4801f
update: make use of seed generator
soumik12345 Jul 17, 2023
5325130
update: replace tf.reduce_all with ops.all + pass data_format and bac…
soumik12345 Jul 17, 2023
f4dcd27
update: remove tf.TensorShape
soumik12345 Jul 17, 2023
c02afb2
remove: redundant line
soumik12345 Jul 17, 2023
2446143
update: replace ops with self.backend.ops
soumik12345 Jul 17, 2023
2b6a051
fix: remove unnecessary change
soumik12345 Jul 17, 2023
5848536
fix: linting
soumik12345 Jul 17, 2023
e6296a2
fix: typo
soumik12345 Jul 17, 2023
3c7cf5b
update: inputs
soumik12345 Jul 17, 2023
b030c49
fix: lint
soumik12345 Jul 17, 2023
8eccb03
update: ops.all operation
soumik12345 Jul 18, 2023
dce17bb
update: ops.all operation
soumik12345 Jul 18, 2023
ff9ebdd
update: random_crop dtype
soumik12345 Jul 18, 2023
ba82b84
update: RandomCrop
soumik12345 Jul 18, 2023
f92571c
update: RandomCrop
soumik12345 Jul 18, 2023
8bc71a7
update: RandomCrop
soumik12345 Jul 18, 2023
16e4b3a
update: RandomCrop.py
soumik12345 Jul 18, 2023
8316c2a
Merge branch 'keras-team:main' into port/random-crop
soumik12345 Jul 19, 2023
4f2d6d7
Merge branch 'keras-team:main' into port/random-crop
soumik12345 Jul 22, 2023
ff9fb6c
update: remove commented code
soumik12345 Jul 22, 2023
75d6b7f
update: int32 dtype
soumik12345 Jul 22, 2023
d0f2bb2
update: add squeeze if inputs are not batched
soumik12345 Jul 22, 2023
b9162d9
Merge branch 'keras-team:main' into port/random-crop
soumik12345 Jul 23, 2023
5f8a571
Merge branch 'keras-team:main' into port/random-crop
soumik12345 Jul 24, 2023
9694bc2
update: seed generator
soumik12345 Jul 24, 2023
5b32b46
update: type cast
soumik12345 Jul 24, 2023
140 changes: 110 additions & 30 deletions keras_core/layers/preprocessing/random_crop.py
@@ -1,14 +1,14 @@
import numpy as np

from keras_core import backend
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
from keras_core.random.seed_generator import SeedGenerator
from keras_core.utils import backend_utils
from keras_core.utils.module_utils import tensorflow as tf
from keras_core.utils import image_utils


@keras_core_export("keras_core.layers.RandomCrop")
class RandomCrop(Layer):
class RandomCrop(TFDataLayer):
"""A preprocessing layer which randomly crops images during training.

During training, this layer will randomly choose a location to crop images
@@ -52,41 +52,121 @@ class RandomCrop(Layer):
`name` and `dtype`.
"""

def __init__(self, height, width, seed=None, name=None, **kwargs):
if not tf.available:
raise ImportError(
"Layer RandomCrop requires TensorFlow. "
"Install it via `pip install tensorflow`."
)

def __init__(
self, height, width, seed=None, data_format=None, name=None, **kwargs
Contributor comment: Note: you could just leave name to kwargs
):
super().__init__(name=name, **kwargs)
self.height = height
self.width = width
self.seed = seed or backend.random.make_default_seed()
self.layer = tf.keras.layers.RandomCrop(
height=height,
width=width,
seed=self.seed,
name=name,
)
self.seed_generator = SeedGenerator(seed)
self.data_format = backend.standardize_data_format(data_format)

if self.data_format == "channels_first":
self.height_axis = -2
self.width_axis = -1
elif self.data_format == "channels_last":
self.height_axis = -3
self.width_axis = -2

self.supports_masking = False
self.supports_jit = False
self._convert_input_args = False
self._allow_non_tensor_positional_args = True

def call(self, inputs, training=True):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
outputs = self.layer.call(inputs, training=training)
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
inputs = self.backend.cast(inputs, self.compute_dtype)
input_shape = self.backend.shape(inputs)
is_batched = len(input_shape) > 3
inputs = (
self.backend.numpy.expand_dims(inputs, axis=0)
if not is_batched
else inputs
)

h_diff = input_shape[self.height_axis] - self.height
w_diff = input_shape[self.width_axis] - self.width

def random_crop():
Contributor comment: To avoid an eager performance overhead, please move this to a method, _randomly_crop_inputs (otherwise redefining the function at each execution is slow). (A sketch of this refactor appears after the diff.)

input_height, input_width = (
input_shape[self.height_axis],
input_shape[self.width_axis],
)

h_start = self.backend.cast(
ops.random.uniform(
(),
0,
maxval=float(input_height - self.height + 1),
dtype=inputs.dtype,
seed=self.seed_generator,
),
"int32",
)
w_start = self.backend.cast(
ops.random.uniform(
(),
0,
maxval=float(input_width - self.width + 1),
dtype=inputs.dtype,
seed=self.seed_generator,
),
"int32",
)
if self.data_format == "channels_last":
return inputs[
:,
h_start : h_start + self.height,
w_start : w_start + self.width,
]
else:
return inputs[
:,
:,
h_start : h_start + self.height,
w_start : w_start + self.width,
]

def resize():
outputs = image_utils.smart_resize(
inputs,
[self.height, self.width],
data_format=self.data_format,
backend_module=self.backend,
)
# smart_resize will always output float32, so we need to re-cast.
return self.backend.cast(outputs, self.compute_dtype)

outputs = self.backend.cond(
self.backend.numpy.all((training, h_diff >= 0, w_diff >= 0)),
random_crop,
resize,
)

outputs = (
self.backend.numpy.squeeze(outputs, axis=0)
if not is_batched
else outputs
)

if backend.backend() != "tensorflow" and not backend_utils.in_tf_graph():
outputs = self.backend.convert_to_tensor(outputs)
return outputs

def compute_output_shape(self, input_shape):
return tuple(self.layer.compute_output_shape(input_shape))
def compute_output_shape(self, input_shape, *args, **kwargs):
input_shape = list(input_shape)
input_shape[self.height_axis] = self.height
input_shape[self.width_axis] = self.width
return tuple(input_shape)

def get_config(self):
config = self.layer.get_config()
config.update({"seed": self.seed})
config = super().get_config()
config.update(
{
"height": self.height,
"width": self.width,
"seed": self.seed,
"data_format": self.data_format,
}
)
return config
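
For reference, a minimal usage sketch of the ported layer as exported in the diff (keras_core.layers.RandomCrop). The shapes, seed, and values below are illustrative, not taken from the PR:

    import numpy as np
    from keras_core import layers

    layer = layers.RandomCrop(height=8, width=8, seed=42)

    # Batched channels-last images larger than the target size:
    # with training=True a random 8x8 window is cropped from each image.
    images = np.random.uniform(size=(2, 32, 32, 3)).astype("float32")
    print(layer(images, training=True).shape)  # (2, 8, 8, 3)

    # Inputs smaller than the target fall back to the smart_resize branch,
    # so the output shape is still (height, width).
    small = np.zeros((2, 4, 4, 3), dtype="float32")
    print(layer(small, training=True).shape)  # (2, 8, 8, 3)

And a hedged sketch of the refactor suggested in the review above: hoisting the nested random_crop() closure into a method so it is not redefined on every call(). The method name _randomly_crop_inputs comes from the review comment; the body simply mirrors the diffed closure (same imports as the module) and is an illustration, not the merged implementation:

    def _randomly_crop_inputs(self, inputs, input_shape):
        # Same logic as the nested random_crop() closure, defined once on
        # the class instead of per call().
        input_height = input_shape[self.height_axis]
        input_width = input_shape[self.width_axis]
        h_start = self.backend.cast(
            ops.random.uniform(
                (),
                0,
                maxval=float(input_height - self.height + 1),
                dtype=inputs.dtype,
                seed=self.seed_generator,
            ),
            "int32",
        )
        w_start = self.backend.cast(
            ops.random.uniform(
                (),
                0,
                maxval=float(input_width - self.width + 1),
                dtype=inputs.dtype,
                seed=self.seed_generator,
            ),
            "int32",
        )
        if self.data_format == "channels_last":
            return inputs[
                :,
                h_start : h_start + self.height,
                w_start : w_start + self.width,
            ]
        return inputs[
            :,
            :,
            h_start : h_start + self.height,
            w_start : w_start + self.width,
        ]

call() would then hand the backend conditional a bound closure, e.g. lambda: self._randomly_crop_inputs(inputs, input_shape), in place of random_crop.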