Comments (9)
cc @muammar
In addition to this example, I'd also link to an integration between skorch (a Scikit-learn wrapper for PyTorch) and Dask-ML's ParallelPostFit.
from dask-examples.
Should have an example ready tomorrow.
hopefully doesn't get too much into Torch or a dataset.
I think we'll want to go into some detail about torch.utils.data.Dataset, because it's not 100% straightforward how to get the data loaded onto workers. To predict for a directory of images, I had to write the following myself:
import builtins
import glob
import os

import torch
from PIL import Image


def default_loader(path, fs=builtins):
    # `fs` only needs an `open(path, mode)` method: the builtins module
    # for local files, or a filesystem object for remote storage.
    with fs.open(path, 'rb') as f:
        img = Image.open(f).convert("RGB")
    return img


class FileDataset(torch.utils.data.Dataset):
    def __init__(self, files, transform=None, target_transform=None,
                 classes=None,
                 loader=default_loader):
        self.files = files
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        if classes is None:
            # Infer class names from each file's parent directory.
            classes = list(sorted(set(x.split(os.path.sep)[-2] for x in files)))
        else:
            classes = list(classes)
        self.classes = classes

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        filename = self.files[index]
        img = self.loader(filename)
        # The target is the index of the parent directory's name in `classes`.
        target = self.classes.index(filename.split(os.path.sep)[-2])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
and use it as
files = glob.glob("hymenoptera_data/val/*/*.jpg")
dataset = FileDataset(files, transform=data_transforms['val'])
For s3, the usage would be FileDataset(files, ..., loader=functools.partial(default_loader, fs=s3fs.S3FileSystem(...))). As a relative newcomer to PyTorch, writing that wasn't 100% straightforward.
Things seem to be working out well after that. PyTorch models seem to (de)serialize much better than TensorFlow's did last time I tried.
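To illustrate why the functools.partial trick works, here is a minimal sketch using only the standard library. The point is that the loader needs nothing from `fs` beyond an `open(path, mode)` method. `InMemoryFS` and `read_bytes` are hypothetical names invented for this sketch; `InMemoryFS` merely stands in for a remote filesystem object such as s3fs.S3FileSystem, and `read_bytes` returns raw bytes rather than decoding an image.

```python
import builtins
import functools
import io


def read_bytes(path, fs=builtins):
    """Read a file through any object exposing `open(path, mode)`."""
    with fs.open(path, "rb") as f:
        return f.read()


class InMemoryFS:
    """Hypothetical stand-in for a remote filesystem object."""

    def __init__(self, blobs):
        self.blobs = blobs  # maps path -> bytes

    def open(self, path, mode="rb"):
        # Return a file-like object, as a real filesystem's open() would.
        return io.BytesIO(self.blobs[path])


fs = InMemoryFS({"bucket/val/ants/0.jpg": b"fake image bytes"})

# Bind the filesystem once; the result has the one-argument loader signature
# that a dataset's __getitem__ expects.
remote_loader = functools.partial(read_bytes, fs=fs)
payload = remote_loader("bucket/val/ants/0.jpg")
```

Because the filesystem is bound ahead of time, the dataset class itself never has to know whether files are local or remote.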
Do we need to use the Torch Dataset API here?
because it's not 100% straightforward how to get the data loaded onto workers
I guess my hope is that, for image data at least, we could just pass around NumPy arrays. So we might create dask.delayed objects using skimage.io.imread or something similar (maybe like https://blog.dask.org/2019/06/20/load-image-data , but before the dask array bit).
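A rough sketch of that idea, under some assumptions: `lazy_images` is a hypothetical helper, and the reader function is passed in as a parameter so the example runs without image files. Here it reads small .npy arrays with np.load; for real images the reader could be skimage.io.imread instead.

```python
import os
import tempfile

import dask
import numpy as np


def lazy_images(paths, reader):
    """Return one dask.delayed object per path; nothing is read yet."""
    lazy_reader = dask.delayed(reader)
    return [lazy_reader(path) for path in paths]


# Stand-in data: small arrays saved as .npy files (an assumption made
# purely so the sketch is self-contained).
tmpdir = tempfile.mkdtemp()
paths = []
for i in range(3):
    path = os.path.join(tmpdir, f"img_{i}.npy")
    np.save(path, np.full((4, 4, 3), i, dtype=np.uint8))
    paths.append(path)

lazy = lazy_images(paths, reader=np.load)  # for real images: reader=skimage.io.imread
images = dask.compute(*lazy)               # tuple of NumPy arrays, loaded in parallel
```

On a distributed cluster the delayed reads would run on the workers, so the arrays would never need to pass through the client.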
Also, if you haven't seen it, this video is nice: https://developer.download.nvidia.com/video/gputechconf/gtc/2019/video/S9198/s9198-dask-and-v100s-for-fast-distributed-batch-scoring-of-computer-vision-workloads.mp4
I'm curious about inputs of Dask Arrays and outputs of model predictions too. I think PyTorch Datasets will need to play an intermediate role; at least that's what skorch uses when tracing net.py's Net.predict to net.py#L1150.
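One way to picture Dask Array inputs and prediction outputs is block-wise prediction, where each chunk is handed to the model as an ordinary NumPy array. The sketch below is only an illustration of that pattern: `predict_block` is a hypothetical stand-in for a real model's predict method (it just thresholds row sums), and no PyTorch Dataset is involved.

```python
import dask.array as da
import numpy as np


def predict_block(block):
    """Hypothetical stand-in for model.predict on one in-memory NumPy chunk."""
    return (block.sum(axis=1) > 1.0).astype(np.int64)


features = np.array([[0.1, 0.2], [0.9, 0.8], [0.5, 0.4], [1.0, 1.0]])
# Chunk along rows only, so every block sees all features of its samples.
X = da.from_array(features, chunks=(2, 2))

# One prediction per row: the feature axis is dropped from the output.
predictions = X.map_blocks(predict_block, dtype=np.int64, drop_axis=1)
result = predictions.compute()
```

The output stays a lazy Dask Array until compute() is called, so predictions for data larger than memory can be written out chunk by chunk.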
Distributed training would also be interesting of course, but my guess is that it's more of an open problem.
It's also mentioned in dask/distributed#2581
Skorch looks interesting to me. Can the wrapper be used after loading the model from disk where the wrapper was not used?
I've tried applying the Dask-ML ParallelPostFit wrapper to a pre-trained model, and I remember having to do a few manual steps before running predictions. I need to dig up that code.
Can the wrapper be used after loading the model from disk where the wrapper was not used?
Yup. The underlying model is an attribute (.module_), so it's simple:
import torch
from skorch import NeuralNetClassifier


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ...


model = Net()
# Train model

# Save trained model using PyTorch
torch.save(model.state_dict(), "trained_model.pt")

# Use skorch later (not necessarily the training session)
sk_net = NeuralNetClassifier(Net)
sk_net.initialize()

# Load parameters saved with PyTorch
sk_net.module_.load_state_dict(torch.load("trained_model.pt"))