

may- commented on June 7, 2024

Hello,

I suppose you can edit \er\target.txt, setting your Bunny class as 1 0 and Cat class as 0 1.
(I'm not sure which task you are working on, but if it is a standard binary classification task, there is probably more compact and reusable code than mine, e.g. this.)

What I tried to solve in this project is a relation extraction task. The labels in \er\target.txt are one possible interpretation of this task: a positive label means that the entities have a certain relation, and a negative label means that the entities don't share that relation.
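A minimal sketch of the relabeling suggested above, assuming (not confirmed by the repo) that \er\target.txt holds one space-separated one-hot label vector per line, in a fixed column order. `CLASSES`, `to_target_line`, and `write_targets` are hypothetical names; check the real file's layout before using anything like this.

```python
# Hypothetical helper: assumes target.txt stores one space-separated one-hot
# label vector per line, in the column order given by CLASSES.
CLASSES = ["Bunny", "Cat"]  # Bunny -> "1 0", Cat -> "0 1"


def to_target_line(label, classes=CLASSES):
    """Return the one-hot row for `label` as a space-separated string."""
    if label not in classes:
        raise ValueError("unknown class: %r" % (label,))
    return " ".join("1" if c == label else "0" for c in classes)


def write_targets(labels, path, classes=CLASSES):
    """Write one one-hot row per example, in the same order as the inputs."""
    with open(path, "w") as f:
        for label in labels:
            f.write(to_target_line(label, classes) + "\n")


print(to_target_line("Bunny"))  # prints: 1 0
print(to_target_line("Cat"))    # prints: 0 1
```

The row order in the label file has to match the order of the input sentences, so `write_targets` simply preserves the order of `labels`.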

Moreover, how can I get the actual predictions?

You mean you want predictions in the form of 0/1 labels instead of precision/recall values? Then you can edit the evaluate function in eval.py so that it returns 0/1 labels. For example:

# comment out line 56
# loss, eval = sess.run([m.total_loss, m.eval_op], feed_dict=feed) 

# get probability
prob = sess.run(m.logits, feed_dict=feed) 

# get "0/1" labels (in case of threshold = 0.5);
# prob is already a NumPy array here, so use NumPy instead of building a new TF op
pred = np.where(prob > 0.5, 1, 0)

# return pred

(Sorry, I haven't checked whether it works, so no guarantee...)
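The 0/1 conversion can be checked on its own in plain NumPy, outside any TensorFlow session. This is a minimal sketch, assuming `prob` already holds probabilities in [0, 1]; the function names are hypothetical. The `argmax` variant is shown only as an alternative for two-column one-hot outputs such as "1 0" / "0 1", not as what eval.py does.

```python
import numpy as np


def to_binary_labels(prob, threshold=0.5):
    """Element-wise 0/1 labels: 1 where prob > threshold, else 0."""
    return np.where(prob > threshold, 1, 0)


def to_class_index(prob):
    """Alternative for one-hot rows: index of the highest-scoring column."""
    return np.argmax(prob, axis=1)


prob = np.array([[0.2, 0.9],
                 [0.7, 0.1]])
print(to_binary_labels(prob))
print(to_class_index(prob))
```

Element-wise thresholding can produce rows like [1, 1] or [0, 0] when both scores fall on the same side of the threshold, whereas `argmax` always picks exactly one class per row.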

Did I answer your question? If you need further help, let me know exactly which part of my code you didn't understand.

from cnn-re-tf.

boshuru commented on June 7, 2024

Hi, thanks for your timely reply. I tried your suggestion by adding the following function to eval.py:

def predict(eval_data, config):
    """ Build evaluation graph and run. """

    with tf.Graph().as_default():
        with tf.variable_scope('cnn'):
            if config.get('contextwise', False):
                import cnn_context
                m = cnn_context.Model(config, is_train=False)
            else:
                import cnn
                m = cnn.Model(config, is_train=False)
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(config['train_dir'])
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise IOError("Loading checkpoint file failed!")

            print("\nStart evaluation on test set ...\n")
            if config.get('contextwise', False):
                left_batch, middle_batch, right_batch, y_batch, _ = zip(*eval_data)
                feed = {m.left: np.array(left_batch),
                        m.middle: np.array(middle_batch),
                        m.right: np.array(right_batch),
                        m.labels: np.array(y_batch)}
            else:
                x_batch, y_batch, _ = zip(*eval_data)
                feed = {m.inputs: np.array(x_batch), m.labels: np.array(y_batch)}

            # get probability
            prob = sess.run(m.logits, feed_dict=feed)

            # get "0/1" labels (in case of threshold = 0.5); prob is a NumPy array
            pred = np.where(prob > 0.5, 1, 0)

    return prob, pred

And in the main() function:

prob, pred = predict(data, restore_param)
for i in prob:
    print(i)

I expected the result to be something like [0, 0.999], but instead I got:
[-3.33696699 3.19706154]
[-6.80935621 7.01474667]
[-3.52916455 3.67665386]
[-3.32578492 3.73686266]
[-8.04429531 8.12351418]
[-3.3291328 3.18136191]
[-10.11254215 10.48918533]
...

Did I get something wrong in the self.logits part?


may- commented on June 7, 2024

Hi,

Ah, OK, I think you additionally need a sigmoid to convert the logits into probabilities between 0 and 1.
Please try this:

# get probability: apply a sigmoid to the logits inside the graph,
# so that sess.run returns concrete values rather than a tensor
prob = sess.run(tf.sigmoid(m.logits), feed_dict=feed)

# get "0/1" labels (in case of threshold = 0.5); prob is a NumPy array
pred = np.where(prob > 0.5, 1, 0)

What does your main function return now? Is it something like [0, 0.999]?
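A quick NumPy check of the sigmoid fix, outside TensorFlow: applying the logistic function element-wise to the first three rows of raw logits posted above reproduces, to within rounding, the probabilities posted in the next reply, and thresholding at 0.5 gives the same labels as thresholding the logits at 0. This is only a sketch for checking numbers, not code from the repo.

```python
import numpy as np


def sigmoid(x):
    """Element-wise logistic function: maps any real logit into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


# First three rows of raw logits posted in this thread.
logits = np.array([
    [-3.33696699, 3.19706154],
    [-6.80935621, 7.01474667],
    [-3.52916455, 3.67665386],
])

prob = sigmoid(logits)
pred = np.where(prob > 0.5, 1, 0)  # same cut-off as logits > 0
print(prob)
print(pred)
```

An element-wise sigmoid scores each of the two columns independently, so a row need not sum to 1 (the first row gives about 0.034 and 0.961). If the model were trained with a softmax loss, a row-wise softmax would be the matching transform; which one applies depends on how cnn.py defines its loss, which this thread does not show.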


boshuru commented on June 7, 2024

Yes, this time it returns something like:
...
[ 0.03432455 0.96072358]
[ 0.00110219 0.99910235]
[ 0.02849371 0.97531712]
...
By the way, should line 156 of cnn.py also be changed
from:
pre, rec = _auc_pr(self._labels, self.logits, threshold * 0.1)
to:
pre, rec = _auc_pr(self._labels, tf.sigmoid(self.logits), threshold * 0.1)
?


may- commented on June 7, 2024

Oh yes, you found a bug in my code! I will correct it. Thanks a lot.
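Why the change to line 156 matters: the sigmoid is monotonic, so a probability threshold t corresponds to a logit threshold of log(t / (1 - t)), and 0.5 corresponds to 0. Comparing raw logits against thresholds meant for probabilities (such as `threshold * 0.1`) therefore shifts every cut-off. The sketch below illustrates this in NumPy; `precision_recall_at` is a hypothetical helper, not the repo's `_auc_pr`, whose internals are not shown here, and the labels and logits are made-up toy values.

```python
import numpy as np


def sigmoid(x):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-x))


def precision_recall_at(labels, scores, threshold):
    """Hypothetical helper (not the repo's _auc_pr): precision and recall of
    the predictions `scores > threshold` against 0/1 labels."""
    preds = scores > threshold
    positives = labels == 1
    tp = int(np.sum(preds & positives))
    fp = int(np.sum(preds & ~positives))
    fn = int(np.sum(~preds & positives))
    precision = tp / float(tp + fp) if tp + fp > 0 else 0.0
    recall = tp / float(tp + fn) if tp + fn > 0 else 0.0
    return precision, recall


labels = np.array([1, 0, 1, 0])
logits = np.array([2.0, 0.3, -0.2, -3.0])

# Thresholding probabilities at 0.5 is equivalent to thresholding logits at 0 ...
print(precision_recall_at(labels, sigmoid(logits), 0.5))
print(precision_recall_at(labels, logits, 0.0))
# ... but thresholding the raw logits at 0.5 moves the cut-off.
print(precision_recall_at(labels, logits, 0.5))
```

Here the example with logit 0.3 (probability about 0.57) counts as positive on the probability scale but negative on the raw-logit scale, which changes precision from 0.5 to 1.0.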
