-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Port TensorFlow implementation of the RandomCrop
layer
#493
Merged
Merged
Changes from 31 commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
7740d17
update: port tf implementation of RandomCrop layer
soumik12345 0d60c8b
update: import + data_format + attempt to replace tf with ops
soumik12345 efb57ff
update: compute_output_shape returns tuple
soumik12345 77aa7a3
update: replace tf.convert_to_tensor with ops.core.convert_to_tensor
soumik12345 72edc5f
update: replace tf.random.uniform with keras_core.random.uniform
soumik12345 963dc14
update: cropping logic
soumik12345 5487542
Merge branch 'keras-team:main' into port/random-crop
soumik12345 43072cb
update: RandomCrop inherits from TFDataLayer + standardized data_format
soumik12345 7e4801f
update: make use of seed generator
soumik12345 5325130
update: replace tf.reduce_all with ops.all + pass data_format and bac…
soumik12345 f4dcd27
update: remove tf.TensorShape
soumik12345 c02afb2
remove: redundant line
soumik12345 2446143
update: replace ops with self.backend.ops
soumik12345 2b6a051
fix: remove unnecessary change
soumik12345 5848536
fix: linting
soumik12345 e6296a2
fix: typo
soumik12345 3c7cf5b
update: inputs
soumik12345 b030c49
fix: lint
soumik12345 8eccb03
update: ops.all operation
soumik12345 dce17bb
update: ops.all operation
soumik12345 ff9ebdd
update: random_crop dtype
soumik12345 ba82b84
update: RandomCrop
soumik12345 f92571c
update: RandomCrop
soumik12345 8bc71a7
update: RandomCrop
soumik12345 16e4b3a
update: RandomCrop.py
soumik12345 8316c2a
Merge branch 'keras-team:main' into port/random-crop
soumik12345 4f2d6d7
Merge branch 'keras-team:main' into port/random-crop
soumik12345 ff9fb6c
update: remove commented code
soumik12345 75d6b7f
update: int32 dtype
soumik12345 d0f2bb2
update: add squeeze if inputs are not batched
soumik12345 b9162d9
Merge branch 'keras-team:main' into port/random-crop
soumik12345 5f8a571
Merge branch 'keras-team:main' into port/random-crop
soumik12345 9694bc2
update: seed generator
soumik12345 5b32b46
update: type cast
soumik12345 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,14 @@ | ||
import numpy as np | ||
|
||
from keras_core import backend | ||
from keras_core import ops | ||
from keras_core.api_export import keras_core_export | ||
from keras_core.layers.layer import Layer | ||
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer | ||
from keras_core.random.seed_generator import SeedGenerator | ||
from keras_core.utils import backend_utils | ||
from keras_core.utils.module_utils import tensorflow as tf | ||
from keras_core.utils import image_utils | ||
|
||
|
||
@keras_core_export("keras_core.layers.RandomCrop") | ||
class RandomCrop(Layer): | ||
class RandomCrop(TFDataLayer): | ||
"""A preprocessing layer which randomly crops images during training. | ||
|
||
During training, this layer will randomly choose a location to crop images | ||
|
@@ -52,41 +52,121 @@ class RandomCrop(Layer): | |
`name` and `dtype`. | ||
""" | ||
|
||
def __init__(self, height, width, seed=None, name=None, **kwargs): | ||
if not tf.available: | ||
raise ImportError( | ||
"Layer RandomCrop requires TensorFlow. " | ||
"Install it via `pip install tensorflow`." | ||
) | ||
|
||
def __init__( | ||
self, height, width, seed=None, data_format=None, name=None, **kwargs | ||
): | ||
super().__init__(name=name, **kwargs) | ||
self.height = height | ||
self.width = width | ||
self.seed = seed or backend.random.make_default_seed() | ||
self.layer = tf.keras.layers.RandomCrop( | ||
height=height, | ||
width=width, | ||
seed=self.seed, | ||
name=name, | ||
) | ||
self.seed_generator = SeedGenerator(seed) | ||
self.data_format = backend.standardize_data_format(data_format) | ||
|
||
if self.data_format == "channels_first": | ||
self.heigh_axis = -2 | ||
self.width_axis = -1 | ||
elif self.data_format == "channels_last": | ||
self.height_axis = -3 | ||
self.width_axis = -2 | ||
|
||
self.supports_masking = False | ||
self.supports_jit = False | ||
self._convert_input_args = False | ||
self._allow_non_tensor_positional_args = True | ||
|
||
def call(self, inputs, training=True): | ||
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): | ||
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs)) | ||
outputs = self.layer.call(inputs, training=training) | ||
if ( | ||
backend.backend() != "tensorflow" | ||
and not backend_utils.in_tf_graph() | ||
): | ||
outputs = backend.convert_to_tensor(outputs) | ||
inputs = self.backend.cast(inputs, self.compute_dtype) | ||
input_shape = self.backend.shape(inputs) | ||
is_batched = len(input_shape) > 3 | ||
inputs = ( | ||
self.backend.numpy.expand_dims(inputs, axis=0) | ||
if not is_batched | ||
else inputs | ||
) | ||
|
||
h_diff = input_shape[self.height_axis] - self.height | ||
w_diff = input_shape[self.width_axis] - self.width | ||
|
||
def random_crop(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To avoid an eager performance overhead, please move this to a method, |
||
input_height, input_width = ( | ||
input_shape[self.height_axis], | ||
input_shape[self.width_axis], | ||
) | ||
|
||
h_start = self.backend.cast( | ||
ops.random.uniform( | ||
(), | ||
0, | ||
maxval=float(input_height - self.height + 1), | ||
dtype=inputs.dtype, | ||
seed=self.seed_generator, | ||
), | ||
"int32", | ||
) | ||
w_start = self.backend.cast( | ||
ops.random.uniform( | ||
(), | ||
0, | ||
maxval=float(input_width - self.width + 1), | ||
dtype=inputs.dtype, | ||
seed=self.seed_generator, | ||
soumik12345 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
), | ||
"int32", | ||
) | ||
if self.data_format == "channels_last": | ||
return inputs[ | ||
:, | ||
h_start : h_start + self.height, | ||
w_start : w_start + self.width, | ||
] | ||
else: | ||
return inputs[ | ||
:, | ||
:, | ||
h_start : h_start + self.height, | ||
w_start : w_start + self.width, | ||
] | ||
|
||
def resize(): | ||
outputs = image_utils.smart_resize( | ||
inputs, | ||
[self.height, self.width], | ||
data_format=self.data_format, | ||
backend_module=self.backend, | ||
) | ||
# smart_resize will always output float32, so we need to re-cast. | ||
return self.backend.cast(outputs, self.compute_dtype) | ||
|
||
outputs = self.backend.cond( | ||
self.backend.numpy.all((training, h_diff >= 0, w_diff >= 0)), | ||
random_crop, | ||
resize, | ||
) | ||
|
||
outputs = ( | ||
self.backend.numpy.squeeze(outputs, axis=0) | ||
if not is_batched | ||
else outputs | ||
) | ||
|
||
if self.backend != "tensorflow" and not backend_utils.in_tf_graph(): | ||
outputs = self.backend.convert_to_tensor(outputs) | ||
return outputs | ||
|
||
def compute_output_shape(self, input_shape): | ||
return tuple(self.layer.compute_output_shape(input_shape)) | ||
def compute_output_shape(self, input_shape, *args, **kwargs): | ||
input_shape = list(input_shape) | ||
input_shape[self.height_axis] = self.height | ||
input_shape[self.width_axis] = self.width | ||
return tuple(input_shape) | ||
|
||
def get_config(self): | ||
config = self.layer.get_config() | ||
config.update({"seed": self.seed}) | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"height": self.height, | ||
"width": self.width, | ||
"seed": self.seed, | ||
"data_format": self.data_format, | ||
} | ||
) | ||
return config |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: you could just leave
name
tokwargs