Comments (9)

mpekalski commented on July 26, 2024

I mean it finishes training on the first batch but then fails during evaluation.

from importance-sampling.

mpekalski commented on July 26, 2024

So, in my case, the shapes of the outputs in evaluate_batch within model_wrappers.py are

[(), (128, 1), (128, 1), (128, 1), (1,)]

whereas with metrics=["accuracy"] the shapes look like

[(), (128, 1), (128, 1), (128, 1), (128, 1)]
[(), (128, 1), (128, 1), (128, 1), (128, 1)]
[(), (128, 1), (128, 1), (128, 1), (128, 1)]
[(), (128, 1), (128, 1), (128, 1), (128, 1)]
[(), (70, 1), (70, 1), (70, 1), (70, 1)]
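In these lists each per-sample output has one row per sample of its batch (four batches of 128 and a final batch of 70), while the failing case ends with (1,). A minimal pure-Python sketch of why that breaks evaluation, assuming (not verified against importance-sampling's code) that the per-batch outputs are stacked into one value per sample; stack_per_sample is a hypothetical helper and plain lists stand in for tensors:

```python
def stack_per_sample(batch_outputs, batch_sizes):
    """Concatenate per-batch metric outputs into one flat per-sample list.

    Hypothetical helper: each output must hold exactly one value per sample
    of its batch, like the (128, 1) and (70, 1) outputs above.
    """
    stacked = []
    for output, size in zip(batch_outputs, batch_sizes):
        if len(output) != size:
            # A batch-level metric yields a single value, like shape (1,),
            # which cannot be matched to the samples of the batch.
            raise ValueError(
                f"expected {size} per-sample values, got {len(output)}"
            )
        stacked.extend(output)
    return stacked


batch_sizes = [128, 128, 128, 128, 70]

# Per-sample metric (e.g. accuracy): one value per sample in every batch.
per_sample = [[1.0] * size for size in batch_sizes]
print(len(stack_per_sample(per_sample, batch_sizes)))  # 582

# Batch-level metric: one value per batch, so stacking fails.
batch_level = [[0.5] for _ in batch_sizes]
try:
    stack_per_sample(batch_level, batch_sizes)
except ValueError as err:
    print(err)  # expected 128 per-sample values, got 1
```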

mpekalski commented on July 26, 2024

Tracing further, it looks like the metric passed as a function is not applied to each sample in the batch, o

I've added print statements around the part that creates the MetricLayer to see what the output is for accuracy and for my own metric, and they differ in shape. It looks like one has the batch dimension (?) and the other does not.

        print(f"metrics: {metrics}")
        metrics = [
            MetricLayer(metric)([y_true, model.get_output_at(0)])
            for metric in metrics
        ]
        print(f"metrics: {metrics}")

Output

metrics: ['accuracy']
metrics: [<tf.Tensor 'metric_layer_1/ExpandDims_1:0' shape=(?, 1) dtype=float32>]
metrics: [<function matthews_correlation at 0x7f65ae953c80>]
metrics: [<tf.Tensor 'metric_layer_1/ExpandDims_1:0' shape=(1,) dtype=float32>]

mpekalski commented on July 26, 2024

In the call method within layers/metrics.py, it looks like both 'accuracy' and the metric passed as a function receive the same inputs (in terms of shape):

    def call(self, inputs, mask=None):
        # Compute the metric
        metric = self.metric_func(*inputs)
        print(f"inputs {inputs}")

Output

inputs [<tf.Tensor 'input_2:0' shape=(?, 1) dtype=float32>, <tf.Tensor 'activation_1/Softmax:0' shape=(?, 1) dtype=float32>]

Yet the metric passed as a function still returns a different shape than the one passed as a string.
It seems like the unpacking, i.e., *inputs, does not work as intended.

mpekalski commented on July 26, 2024

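For me, a fix is to change (in layers/metrics.py)

    metric = self.metric_func(*inputs)

to

    # Apply the metric to each (y_true, y_pred) row of the concatenated tensor
    # (with two (?, 1) inputs the concatenation has shape (?, 2))
    metric = K.map_fn(lambda x: self.metric_func(x[0], x[1]), K.concatenate(inputs, axis=-1))

But this solution is not generic enough, and here you cannot use *x to unpack a tensor.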
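K.map_fn calls the lambda once per row of the concatenated tensor, so with two (?, 1) inputs each x holds the label and the prediction of a single sample. A pure-Python sketch of that row-wise application, assuming a simple Matthews correlation with predictions thresholded at 0.5 (the actual matthews_correlation from this thread is not shown, and map_rows is a hypothetical stand-in for map_fn), illustrates why the result looks wrong: with a single sample at most one of tp, tn, fp, fn is 1, so tp * tn - fp * fn is always 0.

```python
import math

EPSILON = 1e-7  # assumed guard against a zero denominator


def matthews_correlation(y_true, y_pred):
    """Illustrative batch-level MCC; predictions are thresholded at 0.5."""
    pairs = [(t, 1 if p > 0.5 else 0) for t, p in zip(y_true, y_pred)]
    tp, tn = pairs.count((1, 1)), pairs.count((0, 0))
    fp, fn = pairs.count((0, 1)), pairs.count((1, 0))
    denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / (denominator + EPSILON)


def map_rows(metric, y_true, y_pred):
    """Hypothetical stand-in for K.map_fn: call metric on one row at a time."""
    rows = list(zip(y_true, y_pred))  # rows of the concatenated (?, 2) tensor
    # Each row is wrapped in one-element lists because this sketch's metric
    # takes sequences rather than scalars.
    return [metric([x[0]], [x[1]]) for x in rows]


y_true = [1, 0, 1, 1, 0, 0]
y_pred = [0.9, 0.2, 0.4, 0.8, 0.6, 0.1]

print(map_rows(matthews_correlation, y_true, y_pred))  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(round(matthews_correlation(y_true, y_pred), 4))  # 0.3333
```

On the whole batch the same function gives about 0.333, so evaluating it row by row throws away exactly the information the metric needs.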

mpekalski commented on July 26, 2024

The problem I see now with my fix is that I do get one value for my metric (for the whole batch), but I think it is wrong, because this metric should be calculated on the whole batch rather than on single samples. Hence, I do not know whether my PR really fixed the problem and the problem is with the metric, or the other way around.

mpekalski commented on July 26, 2024

OK, so the problem is that my metric needs to be evaluated on the whole batch and does not make sense when calculated on a single sample; hence its per-sample values cannot simply be averaged with mean().
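Accuracy survives mean() because it is itself a mean over samples, so a batch-size-weighted mean of per-batch accuracies equals the accuracy over all samples. Matthews correlation is not a mean over samples, so even a weighted mean of per-batch values generally differs from the value on the full data. A pure-Python sketch with two made-up batches of four samples (the data, the 0.5 threshold and the helper names are illustrative assumptions):

```python
import math

EPSILON = 1e-7  # assumed guard against a zero denominator


def to_labels(y_pred):
    """Threshold predicted probabilities at 0.5 (an assumption)."""
    return [1 if p > 0.5 else 0 for p in y_pred]


def accuracy(y_true, y_pred):
    """Fraction of correctly classified samples: a mean over samples."""
    labels = to_labels(y_pred)
    return sum(t == label for t, label in zip(y_true, labels)) / len(y_true)


def matthews_correlation(y_true, y_pred):
    """Illustrative MCC computed from the confusion counts of the batch."""
    pairs = list(zip(y_true, to_labels(y_pred)))
    tp, tn = pairs.count((1, 1)), pairs.count((0, 0))
    fp, fn = pairs.count((0, 1)), pairs.count((1, 0))
    denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / (denominator + EPSILON)


def weighted_batch_mean(metric, batches):
    """Average per-batch metric values, weighting each batch by its size."""
    total = sum(len(y_true) for y_true, _ in batches)
    weighted = sum(metric(t, p) * len(t) for t, p in batches)
    return weighted / total


batches = [
    ([1, 1, 0, 0], [0.9, 0.3, 0.2, 0.1]),
    ([1, 0, 0, 0], [0.8, 0.7, 0.6, 0.1]),
]
all_true = [t for y_true, _ in batches for t in y_true]
all_pred = [p for _, y_pred in batches for p in y_pred]

# Accuracy: the weighted mean of batches matches the full-data value.
print(weighted_batch_mean(accuracy, batches), accuracy(all_true, all_pred))  # 0.625 0.625

# MCC: the weighted mean of batches does not match the full-data value.
print(round(weighted_batch_mean(matthews_correlation, batches), 3),
      round(matthews_correlation(all_true, all_pred), 3))  # 0.455 0.258
```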

angeloskath commented on July 26, 2024

Hi,

Thank you for the contribution, and sorry for the relatively slow reply. I pushed a fix to the branch scalar-metrics (e5dac90) and also added a test. If you have a use case that is not covered, you can add it there. I will merge and do another release in due time.

Thanks,
Angelos

mpekalski commented on July 26, 2024

No problem. I see your solution is more elegant than mine. :)
