-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Port TensorFlow implementation of the RandomCrop
layer
#493
Conversation
@fchollet
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! The last thing you're going to need is compatibility with tf.data
. To achieve this, you just need to:
- Inherit from
TFDataLayer
(look at other preprocessing layers, like CenterCrop for example) - Use
self.backend
when accessing ops - Pass
backend=self.backend
insmart_resize
- Handle seed generators for the tf.data case -- take a look at RandomBrightness for an example. Basically in the tf.data case we need to make a new seed generator.
@fchollet |
Yes -- this is the only way to make it TF Data compatible. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update -- the code is looking good!
- What about the test failures?
- Do we have sufficient test coverage here? You can use pytest-cov to check for coverage.
) | ||
|
||
def __init__( | ||
self, height, width, seed=None, data_format=None, name=None, **kwargs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: you could just leave name
to kwargs
h_diff = input_shape[self.height_axis] - self.height | ||
w_diff = input_shape[self.width_axis] - self.width | ||
|
||
def random_crop(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid an eager performance overhead, please move this to a method, _randomly_crop_inputs
(otherwise redefining the function at each execution is slow)
Hi @soumik12345 ! Are you still working on this? Can we help? Should we take it over? |
Hi @fchollet I am stuck with the unit tests. Would really appreciate some help. |
I pulled it and fixed it up -- but unfortunately I'm hitting intractable issues with JAX (it works fine with the other backends): JAX does not support dynamic slicing. More specifically, slice sizes must be static in https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html As far as I can tell this seems to prohibit random cropping in JAX. Will investigate further. |
Finally found a fix for JAX. This is now merged. Thank you! You can check my changes FYI. |
This PR ports the TensorFlow implementation of the
RandomCrop
layer tokeras_core
and aims to end the dependency of this layer ontf.keras
.Linked Issue: keras-team/keras#18442