
Comments (6)

Paryavi commented on June 3, 2024

@sachinprasadhs @fchollet @ianstenbit
To resolve the error in this thread, should I limit max_boxes to 32 again at inference time (after model training) as well? If so, how?
I suspect the issue is with the unbatching step, which I copied from the documentation:

visualization_ds = eval_ds.unbatch()
visualization_ds = visualization_ds.ragged_batch(16)
visualization_ds = visualization_ds.shuffle(8)

FYI, my images are small, so for training I padded them to 640 by 640 pixels using a resizing layer. Should I also apply that resizing layer at inference time, and if so, how?
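
On the resizing question: in general, the geometric preprocessing used in training (here, padding small images up to 640×640) should be reproduced at inference so the model sees objects at the same scale. The pure-Python sketch below shows the arithmetic of an aspect-preserving resize-and-pad (letterbox). It assumes padding is added on the bottom/right and boxes are `[x1, y1, x2, y2]` in pixels; `letterbox_params` and `scale_boxes_xyxy` are hypothetical helpers, not KerasCV API, and KerasCV's `Resizing` layer may differ in details.

```python
def letterbox_params(height, width, target=640):
    """Scale factor, resized size and padding for an aspect-preserving resize.

    Assumption: the image is scaled by one factor so it fits inside
    target x target, then padded on the bottom/right up to target x target.
    """
    scale = min(target / height, target / width)
    new_h = round(height * scale)
    new_w = round(width * scale)
    pad_bottom = target - new_h
    pad_right = target - new_w
    return scale, (new_h, new_w), (pad_bottom, pad_right)


def scale_boxes_xyxy(boxes, scale):
    """Rescale [x1, y1, x2, y2] pixel boxes by the letterbox factor.

    With bottom/right padding the origin does not move, so no offset is added.
    """
    return [[coord * scale for coord in box] for box in boxes]


scale, resized, padding = letterbox_params(200, 400)
print(scale, resized, padding)
print(scale_boxes_xyxy([[10, 20, 110, 70]], scale))
```

For example, a 200×400 (height × width) image is scaled by 1.6 to 320×640 and then padded with 320 rows at the bottom; its boxes only need to be multiplied by the same 1.6.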

from keras-cv.

innat-asj commented on June 3, 2024

@Paryavi I think your issue is with ragged tensors under Keras 3, which doesn't support them yet.

NotImplementedError: bounding_box.to_ragged was called using a backend which does not support ragged tensors. Current backend: tensorflow.

Maybe you can pad the bounding boxes to a fixed size instead:

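import tensorflow as tf
import keras
import keras_cv

# input_shape (an int, e.g. 640), bbox_format and eval_ds come from your own pipeline.
preprocessor = keras.Sequential(
    layers=[
        # Resize to input_shape x input_shape, padding to preserve the aspect ratio;
        # the bounding boxes are transformed together with the images.
        keras_cv.layers.Resizing(
            input_shape,
            input_shape,
            bounding_box_format=bbox_format,
            pad_to_aspect_ratio=True,
        ),
    ],
)

def pad_fn(inputs):
    # Convert ragged bounding boxes into dense tensors with at most 32 boxes per image.
    inputs["bounding_boxes"] = keras_cv.bounding_box.to_dense(
        inputs["bounding_boxes"], max_boxes=32
    )
    return inputs

visualization_ds = eval_ds.unbatch()
visualization_ds = visualization_ds.ragged_batch(16)
visualization_ds = visualization_ds.map(
    preprocessor, num_parallel_calls=tf.data.AUTOTUNE
)
visualization_ds = visualization_ds.map(
    pad_fn, num_parallel_calls=tf.data.AUTOTUNE
)
visualization_ds = visualization_ds.prefetch(tf.data.AUTOTUNE)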
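
To illustrate what the padding step does conceptually, here is a NumPy sketch that turns per-image box lists of varying length into fixed-size arrays with `max_boxes` slots, so no ragged tensors are needed downstream. It is not `keras_cv.bounding_box.to_dense` itself; the function name `pad_boxes_to_dense` and the -1 fill value are assumptions made for illustration.

```python
import numpy as np


def pad_boxes_to_dense(boxes_per_image, classes_per_image, max_boxes=32, pad_value=-1):
    """Pad variable-length box lists to fixed (batch, max_boxes, ...) arrays.

    Hypothetical helper: boxes beyond max_boxes are truncated and empty
    slots are filled with pad_value.
    """
    batch = len(boxes_per_image)
    dense_boxes = np.full((batch, max_boxes, 4), pad_value, dtype=np.float32)
    dense_classes = np.full((batch, max_boxes), pad_value, dtype=np.float32)
    for i, (boxes, classes) in enumerate(zip(boxes_per_image, classes_per_image)):
        n = min(len(boxes), max_boxes)
        if n:  # skip images without boxes; their slots stay at pad_value
            dense_boxes[i, :n] = np.asarray(boxes[:n], dtype=np.float32)
            dense_classes[i, :n] = np.asarray(classes[:n], dtype=np.float32)
    return {"boxes": dense_boxes, "classes": dense_classes}


# Two images with 1 and 2 boxes, padded to 3 slots each.
out = pad_boxes_to_dense(
    [[[0, 0, 10, 10]], [[1, 1, 5, 5], [2, 2, 6, 6]]],
    [[0], [1, 2]],
    max_boxes=3,
)
print(out["classes"])
```

Downstream code can then treat any slot whose class is -1 as padding rather than a real object.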

Paryavi commented on June 3, 2024

Thanks @innat-asj
I use the padding in the model-training section, and I also used your padding code after training the model (.fit()), but I get this error when running the last part of your code:

visualization_ds = visualization_ds.map(
    preprocessor, num_parallel_calls=tf.data.AUTOTUNE
)
visualization_ds = visualization_ds.map(
    pad_fn, num_parallel_calls=tf.data.AUTOTUNE
)
visualization_ds = visualization_ds.prefetch(tf.data.AUTOTUNE)

Error:
TypeError                                 Traceback (most recent call last)
in <cell line: 1>()
----> 1 visualization_ds = visualization_ds.map(
      2     preprocessor, num_parallel_calls=tf.data.AUTOTUNE
      3 )
      4 visualization_ds= visualization_ds.map(
      5     pad_fn, num_parallel_calls=tf.data.AUTOTUNE

18 frames
/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    155         bound_signature = None
    156         try:
--> 157             return fn(*args, **kwargs)
    158         except Exception as e:
    159             if hasattr(e, "_keras_call_info_injected"):

TypeError: Sequential.call() got multiple values for argument 'training'
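
This `TypeError` is consistent with how `tf.data.Dataset.map` calls its function: when dataset elements are tuples, their components are passed as separate positional arguments, while a dict element is passed as a single argument. If `eval_ds` yields `(images, bounding_boxes)` tuples (an assumption; the thread does not show how `eval_ds` was built), the second component lands in the `training` slot of `Sequential.call()`. The pure-Python simulation below reproduces that binding conflict with a hypothetical `FakeSequential` class (not Keras; it simply assumes the framework forwards `training` as a keyword) and shows one possible workaround: a wrapper that re-packs the tuple into the `{"images": ..., "bounding_boxes": ...}` dict that KerasCV preprocessing layers expect.

```python
class FakeSequential:
    """Hypothetical stand-in for a Keras model; NOT the real keras.Sequential."""

    def __call__(self, *args, **kwargs):
        # Assumption: the framework forwards a `training` flag as a keyword argument.
        kwargs.setdefault("training", False)
        return self.call(*args, **kwargs)

    def call(self, inputs, training=None, mask=None):
        return inputs


def map_like_tf_data(fn, element):
    # Mimics tf.data.Dataset.map: tuple elements are unpacked into positional
    # arguments, any other element (e.g. a dict) is passed as one argument.
    if isinstance(element, tuple):
        return fn(*element)
    return fn(element)


preprocessor = FakeSequential()


def preprocess_tuple(images, bounding_boxes):
    # Hypothetical wrapper: re-pack the tuple into the single dict argument
    # that the preprocessing model expects.
    return preprocessor({"images": images, "bounding_boxes": bounding_boxes})


try:
    map_like_tf_data(preprocessor, ("images", "boxes"))
except TypeError as err:
    print(err)  # the same "multiple values for argument 'training'" conflict

print(map_like_tf_data(preprocess_tuple, ("images", "boxes")))
```

In the real pipeline the analogous change would be mapping such a wrapper instead of the `Sequential` model directly; this has not been verified against KerasCV here.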

divyashreepathihalli commented on June 3, 2024

@Paryavi Which backend are you using, and which backend do the input tensors come from? JAX and PyTorch do not support ragged tensors.

github-actions commented on June 3, 2024

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions commented on June 3, 2024

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.
