yu54ku / xml-cnn
Implementation of "Deep Learning for Extreme Multi-label Text Classification" using PyTorch.
License: MIT License
On the home page you stated that "Caution: This dataset is tokenized
differently than the one used by Liu et al."
May I ask for more details? Does this mean Liu et al. didn't use the tokens
provided on the RCV1 page?
Thanks
Hello, and thank you for this implementation.
I would like to ask if there is a way to load the data in batches, since my dataset is quite large (50 GB) and cannot fit into memory.
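In case it helps later readers: one standard way to avoid loading everything at once is a streaming dataset that reads one example at a time. A minimal sketch using PyTorch's `IterableDataset` (the file name and tab-separated line format here are hypothetical, not this repo's actual format):

```python
import torch
from torch.utils.data import IterableDataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class StreamingTextDataset(IterableDataset):
    """Yields one example at a time, so the whole file never sits in RAM."""
    def __init__(self, path):
        self.path = path

    def __iter__(self):
        with open(self.path) as f:
            for line in f:
                # Hypothetical format: tab-separated label ids, then token ids.
                labels, tokens = line.rstrip("\n").split("\t")
                yield (torch.tensor([int(t) for t in tokens.split()]),
                       torch.tensor([int(l) for l in labels.split()]))

def collate(batch):
    # Pad variable-length token and label sequences within each batch.
    seqs, labels = zip(*batch)
    return (pad_sequence(seqs, batch_first=True),
            pad_sequence(labels, batch_first=True))

# Batches are pulled lazily; only one batch is materialized at a time.
loader = DataLoader(StreamingTextDataset("train.tsv"), batch_size=64,
                    collate_fn=collate)
```

Adapting the repo's torchtext-based preprocessing to this would take some work, but the pattern itself keeps memory use proportional to the batch size, not the file size.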
As the title says: I can run it on one machine, but on another machine
it fails with the following message. Do you see why this happened?
Many thanks
$ python3 train.py
============================== Normal Train Mode ==============================
------------------------------------ Params -----------------------------------
[('batch_size', 64), ('cache_path', 'cache'), ('measure', 'p@1'), ('sequence_length', 500)]
-------------------------------------------------------------------------------
--------------------------------- Hyper Params --------------------------------
[('d_max_pool_p', [125, 128, 128]), ('filter_channels', 128), ('filter_sizes', [2, 4, 8]), ('hidden_dims', 512), ('learning_rate', 0.0005099137446356937), ('stride', [2, 1, 1])]
-------------------------------------------------------------------------------
Loading data... /home/cjlin/.local/lib/python3.6/site-packages/torchtext/data/field.py:36: UserWarning: RawField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.
warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)
/home/cjlin/.local/lib/python3.6/site-packages/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.
warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)
/home/cjlin/.local/lib/python3.6/site-packages/torchtext/data/example.py:68: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.
warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)
/home/cjlin/.local/lib/python3.6/site-packages/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.
warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)
Done.
Converting text to ID... Traceback (most recent call last):
File "train.py", line 129, in <module>
main()
File "train.py", line 106, in main
trainer.preprocess()
File "/home/cjlin/xml-cnn/build_problem.py", line 151, in preprocess
self.TEXT.vocab.load_vectors("glove.6B.300d")
File "/home/cjlin/.local/lib/python3.6/site-packages/torchtext/vocab.py", line 184, in load_vectors
vectors[idx] = pretrained_aliases[vector](**kwargs)
File "/home/cjlin/.local/lib/python3.6/site-packages/torchtext/vocab.py", line 487, in __init__
super(GloVe, self).__init__(name, url=url, **kwargs)
File "/home/cjlin/.local/lib/python3.6/site-packages/torchtext/vocab.py", line 326, in __init__
self.cache(name, cache, url=url, max_vectors=max_vectors)
File "/home/cjlin/.local/lib/python3.6/site-packages/torchtext/vocab.py", line 368, in cache
with zipfile.ZipFile(dest, "r") as zf:
File "/usr/lib/python3.6/zipfile.py", line 1131, in __init__
self._RealGetContents()
File "/usr/lib/python3.6/zipfile.py", line 1198, in _RealGetContents
raise BadZipFile("File is not a zip file")
zipfile.BadZipFile: File is not a zip file
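For reference, this error usually means an earlier download of the GloVe archive was interrupted, leaving a truncated `glove.6B.zip` in torchtext's vector cache. A small sketch to detect and clear such a file (the `.vector_cache` path is an assumption; check where `load_vectors` actually writes on your machine):

```python
import os
import zipfile

# Assumed torchtext vector-cache location; adjust to your setup.
cache = os.path.join(".vector_cache", "glove.6B.zip")

if os.path.exists(cache) and not zipfile.is_zipfile(cache):
    # A truncated or partial download; remove it so torchtext
    # re-downloads the archive on the next run.
    os.remove(cache)
```

Deleting the corrupt archive by hand and re-running `train.py` should have the same effect.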
I noticed that you have
d_max_pool_p: [125, 128, 128]
filter_sizes: [2, 4, 8]
stride: [2, 1, 1]
May I ask why the first stride is 2?
Also, how did you decide the values shown in params.yml? The authors of
XML-CNN mentioned
filter_channels: 128
filter_sizes: [2, 4, 8]
hidden_dims: 512
but how about the others?
Thanks
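Not the author, but the pooling sizes look tied to the convolution output lengths. Assuming 1-D convolution over the token axis with no padding (which the listed params suggest), the output length is floor((L - k) / s) + 1, so with sequence_length 500 the three branches produce 250, 497, and 493 positions:

```python
def conv_out_len(L, k, s):
    """Length of a 1-D convolution output with no padding: floor((L - k) / s) + 1."""
    return (L - k) // s + 1

seq_len = 500                        # sequence_length from the printed params
branches = [(2, 2), (4, 1), (8, 1)]  # (filter_size, stride) pairs
lengths = [conv_out_len(seq_len, k, s) for k, s in branches]
print(lengths)  # [250, 497, 493]
```

One guess: with stride 2, the first branch's 250 positions divide evenly into the 125 dynamic-max-pool chunks (d_max_pool_p[0]), though only the author can confirm the intent.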
My two runs of the code showed
-------------- Best Epoch: 16 (p@1: 0.95296472311019897461) -------------
-------------- Best Epoch: 14 (p@1: 0.95154267549514770508) -------------
This p@1 result is slightly worse than the 96.86 reported
in the XML-CNN paper. Do you think the reason is
the slightly different tokenization?
Actually I would expect the opposite: since your version keeps more tokens,
more information is available, so shouldn't the results be at least as good?
Thanks
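For anyone reproducing these numbers: p@1 here is the fraction of documents whose single top-scoring predicted label appears in the document's true label set. A minimal sketch (the scores and label sets below are made up for illustration):

```python
import numpy as np

def precision_at_1(scores, true_labels):
    """scores: (n_docs, n_labels) array; true_labels: list of sets of gold label ids."""
    top1 = scores.argmax(axis=1)
    hits = sum(int(t in gold) for t, gold in zip(top1, true_labels))
    return hits / len(true_labels)

scores = np.array([[0.1, 0.9, 0.0],
                   [0.7, 0.2, 0.1]])
# Doc 0's top label (1) is in {1, 2}: hit. Doc 1's top label (0) is not in {2}: miss.
print(precision_at_1(scores, [{1, 2}, {2}]))  # 0.5
```

Small differences in tokenization change both the vocabulary and which documents are scored correctly, so gaps of a point or so against the paper are plausible either way.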