Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port TensorFlow implementation of the RandomCrop layer #493

Merged
merged 34 commits into from
Aug 6, 2023

Conversation

soumik12345
Copy link
Contributor

@soumik12345 soumik12345 commented Jul 15, 2023

This PR ports the TensorFlow implementation of the RandomCrop layer to keras_core and aims to end the dependency of this layer on tf.keras.

Linked Issue: keras-team/keras#18442

@soumik12345
Copy link
Contributor Author

@fchollet
I ran the corresponding tests locally, and two of them failed.

keras_core/layers/preprocessing/random_crop_test.py::RandomCropTest::test_predicting_with_longer_height FAILED                [ 16%]
keras_core/layers/preprocessing/random_crop_test.py::RandomCropTest::test_predicting_with_longer_width PASSED                 [ 33%]
keras_core/layers/preprocessing/random_crop_test.py::RandomCropTest::test_random_crop PASSED                                  [ 50%]
keras_core/layers/preprocessing/random_crop_test.py::RandomCropTest::test_random_crop_full PASSED                             [ 66%]
keras_core/layers/preprocessing/random_crop_test.py::RandomCropTest::test_random_crop_partial PASSED                          [ 83%]
keras_core/layers/preprocessing/random_crop_test.py::RandomCropTest::test_tf_data_compatibility FAILED                        [100%]

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! 👍

keras_core/layers/preprocessing/random_crop.py Outdated Show resolved Hide resolved
keras_core/layers/preprocessing/random_crop.py Outdated Show resolved Hide resolved
keras_core/layers/preprocessing/random_crop.py Outdated Show resolved Hide resolved
keras_core/layers/preprocessing/random_crop.py Outdated Show resolved Hide resolved
keras_core/layers/preprocessing/random_crop.py Outdated Show resolved Hide resolved
@soumik12345 soumik12345 requested a review from fchollet July 16, 2023 00:52
Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! The last thing you're going to need is compatibility with tf.data. To achieve this, you just need to:

  • Inherit from TFDataLayer (look at other preprocessing layers, like CenterCrop for example)
  • Use self.backend when accessing ops
  • Pass backend=self.backend in smart_resize
  • Handle seed generators for the tf.data case -- take a look at RandomBrightness for an example. Basically in the tf.data case we need to make a new seed generator.

keras_core/layers/preprocessing/random_crop.py Outdated Show resolved Hide resolved
@soumik12345
Copy link
Contributor Author

Use self.backend when accessing ops

@fchollet
Should all ops be replace by self.backend.ops?

@soumik12345 soumik12345 requested a review from fchollet July 17, 2023 21:48
@fchollet
Copy link
Contributor

Yes -- this is the only way to make it TF Data compatible.

@soumik12345
Copy link
Contributor Author

@fchollet
I'm getting this error in the CI for TF tests

AttributeError: module 'keras_core.backend.tensorflow' has no attribute 'ops'

@soumik12345 soumik12345 requested a review from fchollet July 22, 2023 07:24
Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update -- the code is looking good!

  • What about the test failures?
  • Do we have sufficient test coverage here? You can use pytest-cov to check for coverage.

)

def __init__(
self, height, width, seed=None, data_format=None, name=None, **kwargs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: you could just leave name to kwargs

h_diff = input_shape[self.height_axis] - self.height
w_diff = input_shape[self.width_axis] - self.width

def random_crop():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid an eager performance overhead, please move this to a method, _randomly_crop_inputs (otherwise redefining the function at each execution is slow)

keras_core/layers/preprocessing/random_crop.py Outdated Show resolved Hide resolved
@fchollet
Copy link
Contributor

fchollet commented Aug 5, 2023

Hi @soumik12345 ! Are you still working on this? Can we help? Should we take it over?

@soumik12345
Copy link
Contributor Author

Hi @soumik12345 ! Are you still working on this? Can we help? Should we take it over?

Hi @fchollet I am stuck with the unit tests. Would really appreciate some help.

@fchollet
Copy link
Contributor

fchollet commented Aug 5, 2023

I pulled it and fixed it up -- but unfortunately I'm hitting intractable issues with JAX (it works fine with the other backends): JAX does not support dynamic slicing. More specifically, slice sizes must be static in https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html

As far as I can tell this seems to prohibit random cropping in JAX. Will investigate further.

@fchollet fchollet merged commit bee023d into keras-team:main Aug 6, 2023
@fchollet
Copy link
Contributor

fchollet commented Aug 6, 2023

Finally found a fix for JAX. This is now merged. Thank you! You can check my changes FYI.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants