Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port TensorFlow implementation of the RandomCrop layer #493

Merged
merged 34 commits into from
Aug 6, 2023
Merged
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7740d17
update: port tf implementation of RandomCrop layer
soumik12345 Jul 15, 2023
0d60c8b
update: import + data_format + attempt to replace tf with ops
soumik12345 Jul 15, 2023
efb57ff
update: compute_output_shape returns tuple
soumik12345 Jul 15, 2023
77aa7a3
update: replace tf.convert_to_tensor with ops.core.convert_to_tensor
soumik12345 Jul 15, 2023
72edc5f
update: replace tf.random.uniform with keras_core.random.uniform
soumik12345 Jul 15, 2023
963dc14
update: cropping logic
soumik12345 Jul 16, 2023
5487542
Merge branch 'keras-team:main' into port/random-crop
soumik12345 Jul 17, 2023
43072cb
update: RandomCrop inherits from TFDataLayer + standardized data_format
soumik12345 Jul 17, 2023
7e4801f
update: make use of seed generator
soumik12345 Jul 17, 2023
5325130
update: replace tf.reduce_all with ops.all + pass data_format and bac…
soumik12345 Jul 17, 2023
f4dcd27
update: remove tf.TensorShape
soumik12345 Jul 17, 2023
c02afb2
remove: redundant line
soumik12345 Jul 17, 2023
2446143
update: replace ops with self.backend.ops
soumik12345 Jul 17, 2023
2b6a051
fix: remove unnecessary change
soumik12345 Jul 17, 2023
5848536
fix: linting
soumik12345 Jul 17, 2023
e6296a2
fix: typo
soumik12345 Jul 17, 2023
3c7cf5b
update: inputs
soumik12345 Jul 17, 2023
b030c49
fix: lint
soumik12345 Jul 17, 2023
8eccb03
update: ops.all operation
soumik12345 Jul 18, 2023
dce17bb
update: ops.all operation
soumik12345 Jul 18, 2023
ff9ebdd
update: random_crop dtype
soumik12345 Jul 18, 2023
ba82b84
update: RandomCrop
soumik12345 Jul 18, 2023
f92571c
update: RandomCrop
soumik12345 Jul 18, 2023
8bc71a7
update: RandomCrop
soumik12345 Jul 18, 2023
16e4b3a
update: RandomCrop.py
soumik12345 Jul 18, 2023
8316c2a
Merge branch 'keras-team:main' into port/random-crop
soumik12345 Jul 19, 2023
4f2d6d7
Merge branch 'keras-team:main' into port/random-crop
soumik12345 Jul 22, 2023
ff9fb6c
update: remove commented code
soumik12345 Jul 22, 2023
75d6b7f
update: int32 dtype
soumik12345 Jul 22, 2023
d0f2bb2
update: add squeeze if inputs are not batched
soumik12345 Jul 22, 2023
b9162d9
Merge branch 'keras-team:main' into port/random-crop
soumik12345 Jul 23, 2023
5f8a571
Merge branch 'keras-team:main' into port/random-crop
soumik12345 Jul 24, 2023
9694bc2
update: seed generator
soumik12345 Jul 24, 2023
5b32b46
update: type cast
soumik12345 Jul 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions keras_core/layers/preprocessing/random_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
from keras_core.utils import backend_utils, image_utils
soumik12345 marked this conversation as resolved.
Show resolved Hide resolved
from keras_core.utils.module_utils import tensorflow as tf


H_AXIS = -3
W_AXIS = -2
soumik12345 marked this conversation as resolved.
Show resolved Hide resolved


@keras_core_export("keras_core.layers.RandomCrop")
class RandomCrop(Layer):
"""A preprocessing layer which randomly crops images during training.
Expand Down Expand Up @@ -60,13 +64,9 @@ def __init__(self, height, width, seed=None, name=None, **kwargs):
)

super().__init__(name=name, **kwargs)
self.height = height
self.width = width
self.seed = seed or backend.random.make_default_seed()
self.layer = tf.keras.layers.RandomCrop(
height=height,
width=width,
seed=self.seed,
name=name,
)
self.supports_masking = False
self.supports_jit = False
self._convert_input_args = False
Expand All @@ -75,7 +75,33 @@ def __init__(self, height, width, seed=None, name=None, **kwargs):
def call(self, inputs, training=True):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
outputs = self.layer.call(inputs, training=training)

input_shape = tf.shape(inputs)
soumik12345 marked this conversation as resolved.
Show resolved Hide resolved
h_diff = input_shape[H_AXIS] - self.height
w_diff = input_shape[W_AXIS] - self.width

def random_crop():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid an eager performance overhead, please move this to a method, _randomly_crop_inputs (otherwise redefining the function at each execution is slow)

dtype = input_shape.dtype
rands = tf.random.uniform([2], 0, dtype.max, dtype)
h_start = rands[0] % (h_diff + 1)
w_start = rands[1] % (w_diff + 1)
return tf.image.crop_to_bounding_box(
soumik12345 marked this conversation as resolved.
Show resolved Hide resolved
inputs, h_start, w_start, self.height, self.width
)

def resize():
outputs = image_utils.smart_resize(
inputs, [self.height, self.width]
)
# smart_resize will always output float32, so we need to re-cast.
return tf.cast(outputs, self.compute_dtype)

outputs = tf.cond(
tf.reduce_all((training, h_diff >= 0, w_diff >= 0)),
soumik12345 marked this conversation as resolved.
Show resolved Hide resolved
random_crop,
resize,
)

if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
Expand All @@ -84,9 +110,14 @@ def call(self, inputs, training=True):
return outputs

def compute_output_shape(self, input_shape):
return tuple(self.layer.compute_output_shape(input_shape))
input_shape = tf.TensorShape(input_shape).as_list()
soumik12345 marked this conversation as resolved.
Show resolved Hide resolved
input_shape[H_AXIS] = self.height
input_shape[W_AXIS] = self.width
return tf.TensorShape(input_shape)
soumik12345 marked this conversation as resolved.
Show resolved Hide resolved

def get_config(self):
config = self.layer.get_config()
config.update({"seed": self.seed})
config = super().get_config()
config.update(
{"height": self.height, "width": self.width, "seed": self.seed}
)
return config