Hello,
I suppose you can edit \er\target.txt, setting your Bunny class as 1 0 and your Cat class as 0 1.
(I'm not sure which task you are working on, but if it is a standard binary classification task, there is probably more compact and reusable code than mine, e.g. this.)
What I tried to solve in this project is a relation extraction task. The files under \er\target.txt are one possible interpretation of this task: a positive label means that the entities have a certain relation, and a negative label means that the entities don't share that relation.
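As a rough illustration of that label encoding, here is a minimal Python sketch that turns class names into two-column label lines. It assumes (this is not confirmed by the project code) that target.txt holds one space-separated label vector per line, in the same order as the input examples; the mapping `ONE_HOT`, the helper `to_target_line`, and the example class list are hypothetical.

```python
# Hypothetical mapping from class name to a two-column one-hot label line.
ONE_HOT = {"Bunny": "1 0", "Cat": "0 1"}


def to_target_line(class_name):
    """Return the space-separated one-hot label line for one example."""
    return ONE_HOT[class_name]


# Hypothetical classes for three examples, in the same order as the inputs.
classes = ["Bunny", "Cat", "Cat"]
target_lines = [to_target_line(c) for c in classes]
# One label vector per line, newline-terminated (assumed file layout).
target_text = "\n".join(target_lines) + "\n"
```

Writing `target_text` to the label file would then give one row per example.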
Moreover, how can I get the actual predictions?
You mean you want to get predictions in the form of 0/1 labels instead of precision/recall values? Then you can edit the evaluate function in eval.py so that it returns 0/1 labels. For example:
# comment out line 56
# loss, eval = sess.run([m.total_loss, m.eval_op], feed_dict=feed)
# get probability
prob = sess.run(m.logits, feed_dict=feed)
# get "0/1" labels (in case of threshold = 0.5)
pred = np.where(prob > 0.5, 1, 0)  # sess.run returned a NumPy array, so threshold it with NumPy
# return pred
(Sorry, I haven't checked whether it works, so no guarantee...)
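Once `sess.run` has returned the scores, they are an ordinary NumPy array, so the thresholding step needs no further TensorFlow op. A minimal NumPy sketch, using as illustrative input the probability values reported further down in this thread; the `argmax` line is an alternative for the two-column (one-hot) layout, not something taken from the project's code:

```python
import numpy as np

# Illustrative per-class probabilities: two columns, one row per example.
prob = np.array([[0.03432455, 0.96072358],
                 [0.00110219, 0.99910235],
                 [0.02849371, 0.97531712]])

# "0/1" labels with a threshold of 0.5, applied element-wise.
pred = np.where(prob > 0.5, 1, 0)

# Alternative for one-hot targets: index of the highest-scoring column per row.
class_index = np.argmax(prob, axis=1)
```

The element-wise threshold keeps the two-column shape of the labels, while `argmax` collapses each row to a single class index.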
Did I answer your question? If you need further help, let me know which part of my code you didn't understand.
from cnn-re-tf.
Hi, thanks for your quick reply. I tried your suggestion by adding the following function to eval.py:
def predict(eval_data, config):
    """ Build evaluation graph and run. """
    with tf.Graph().as_default():
        with tf.variable_scope('cnn'):
            if config.get('contextwise'):
                import cnn_context
                m = cnn_context.Model(config, is_train=False)
            else:
                import cnn
                m = cnn.Model(config, is_train=False)
        saver = tf.train.Saver(tf.global_variables())
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(config['train_dir'])
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise IOError("Loading checkpoint file failed!")
            print("\nStart evaluation on test set ...\n")
            if config.get('contextwise'):
                left_batch, middle_batch, right_batch, y_batch, _ = zip(*eval_data)
                feed = {m.left: np.array(left_batch),
                        m.middle: np.array(middle_batch),
                        m.right: np.array(right_batch),
                        m.labels: np.array(y_batch)}
            else:
                x_batch, y_batch, _ = zip(*eval_data)
                feed = {m.inputs: np.array(x_batch), m.labels: np.array(y_batch)}
            # get probability
            prob = sess.run(m.logits, feed_dict=feed)
            # get "0/1" labels (in case of threshold = 0.5); prob is a NumPy array here
            pred = np.where(prob > 0.5, 1, 0)
            return prob, pred
And in the main() function:
prob, pred = predict(data, restore_param)
for i in prob:
    print(i)
I expected the result to be something like [0, 0.999], but instead I got:
[-3.33696699 3.19706154]
[-6.80935621 7.01474667]
[-3.52916455 3.67665386]
[-3.32578492 3.73686266]
[-8.04429531 8.12351418]
[-3.3291328 3.18136191]
[-10.11254215 10.48918533]
...
Did I get something wrong in the self.logits part?
Hi,
Ah, OK, I think you additionally need a sigmoid to convert the logits into probabilities between 0 and 1.
Please try this:
# get probability
prob = sess.run(tf.sigmoid(m.logits), feed_dict=feed)  # sigmoid in the graph, so prob is a NumPy array in (0, 1)
# get "0/1" labels (in case of threshold = 0.5)
pred = np.where(prob > 0.5, 1, 0)
What does your main function return now? Is it something like [0, 0.999]?
Yes, this time it returns something like:
...
[ 0.03432455 0.96072358]
[ 0.00110219 0.99910235]
[ 0.02849371 0.97531712]
...
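As a quick sanity check of the sigmoid fix, applying the logistic function 1 / (1 + exp(-x)) element-wise in NumPy to the first two rows of raw logits printed earlier reproduces the first two probability rows above (up to float32 rounding). A minimal sketch; the helper name `sigmoid` is just for illustration:

```python
import numpy as np


def sigmoid(x):
    """Element-wise logistic function, mapping any real logit into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


# First two rows of raw logits printed earlier in this thread.
logits = np.array([[-3.33696699, 3.19706154],
                   [-6.80935621, 7.01474667]])
prob = sigmoid(logits)
# "0/1" labels with a threshold of 0.5.
pred = np.where(prob > 0.5, 1, 0)
```

Because the sigmoid is monotonic, thresholding the probabilities at 0.5 picks the same entries as thresholding the raw logits at 0.0.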
By the way, should line 156 of cnn.py also be changed as follows?
From:
pre, rec = _auc_pr(self._labels, self.logits, threshold * 0.1)
To:
pre, rec = _auc_pr(self._labels, tf.sigmoid(self.logits), threshold * 0.1)
Oh, yes, you found a bug in my code!! I will correct it. Thanks a lot.
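To see why the sigmoid matters on that line: the evaluation compares scores against thresholds 0.0, 0.1, ..., 0.9 (which is what `threshold * 0.1` suggests), and that sweep only makes sense for values in (0, 1), whereas raw logits are unbounded. The sketch below does not reproduce `_auc_pr` (its implementation isn't shown here). `precision_recall_at` is a hypothetical NumPy helper, and the toy logits and labels are made up purely for illustration.

```python
import numpy as np


def sigmoid(x):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-x))


def precision_recall_at(labels, scores, threshold):
    """Hypothetical helper: precision/recall when scores above threshold count as positive."""
    pred = scores > threshold
    gold = labels.astype(bool)
    tp = np.sum(pred & gold)
    fp = np.sum(pred & ~gold)
    fn = np.sum(~pred & gold)
    precision = tp / float(tp + fp) if tp + fp > 0 else 0.0
    recall = tp / float(tp + fn) if tp + fn > 0 else 0.0
    return precision, recall


# Toy one-hot labels and logits, made up for illustration only.
labels = np.array([[0, 1], [1, 0], [0, 1]])
logits = np.array([[-3.3, 3.2], [0.4, -0.4], [-1.5, 1.5]])

on_logits = precision_recall_at(labels, logits, 0.5)
on_probs = precision_recall_at(labels, sigmoid(logits), 0.5)
```

At threshold 0.5, the raw logits miss the example whose positive logit is 0.4 (recall 2/3), while the sigmoid outputs classify all three correctly. On raw logits, every value above 0.9 or below 0.0 falls on the same side of all ten thresholds, so the resulting precision/recall curve says little about the model.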