
Comments (9)

stsievert commented on August 14, 2024

cc @muammar

In addition to this example, I'd also link to an integration of skorch (a Scikit-learn wrapper for PyTorch) with Dask-ML's ParallelPostFit.

from dask-examples.

TomAugspurger commented on August 14, 2024

Should have an example ready tomorrow.

hopefully doesn't get too much into Torch or a dataset.

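I think we'll want to go into some detail about torch.utils.data.Dataset, because it's not 100% straightforward how to get the data loaded onto workers. To predict for a directory of images, I had to write the following myself:

import builtins
import glob
import os

import torch
from PIL import Image


def default_loader(path, fs=builtins):
    # `fs` is anything with an `open(path, mode)` method: the builtins
    # module for local files (`__builtins__` is a dict, not a module,
    # outside of __main__), or e.g. an s3fs.S3FileSystem for S3.
    with fs.open(path, 'rb') as f:
        # convert() forces PIL to read the pixels while the file is open.
        img = Image.open(f).convert("RGB")
        return img


class FileDataset(torch.utils.data.Dataset):
    """Image dataset over an explicit list of file paths.

    Each file's label is the name of its parent directory,
    e.g. .../val/ants/1.jpg has label "ants".
    """

    def __init__(self, files, transform=None, target_transform=None,
                 classes=None, loader=default_loader):
        self.files = files
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        if classes is None:
            # Infer classes from parent-directory names, sorted for a stable order.
            classes = sorted(set(x.split(os.path.sep)[-2] for x in files))
        else:
            classes = list(classes)
        self.classes = classes

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        filename = self.files[index]
        img = self.loader(filename)
        # Integer target: position of the parent directory in self.classes.
        target = self.classes.index(filename.split(os.path.sep)[-2])

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target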

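and use it like this:

# data_transforms['val'] is the validation transform from the PyTorch
# transfer-learning tutorial, which uses the same hymenoptera_data set.
files = glob.glob("hymenoptera_data/val/*/*.jpg")
dataset = FileDataset(files, transform=data_transforms['val'])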

For S3, the usage would be FileDataset(files, ..., loader=functools.partial(default_loader, fs=s3fs.S3FileSystem(...))). As a relative newcomer to PyTorch, I didn't find writing that 100% straightforward.
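The loader only needs an object with an open(path, mode) method, and s3fs.S3FileSystem is an fsspec filesystem, so any fsspec filesystem can stand in for it. The sketch below assumes fsspec and Pillow are installed and uses fsspec's in-memory filesystem in place of S3 so it runs without credentials; default_loader is redefined here (with fs required) to keep the block self-contained.

```python
import functools

import fsspec
from PIL import Image


def default_loader(path, fs):
    """Open `path` through any object exposing `open(path, mode)`.

    fsspec filesystems (s3fs.S3FileSystem is one) satisfy this, so the
    same loader reads local, in-memory, or S3 files.
    """
    with fs.open(path, "rb") as f:
        # convert() forces PIL to read the pixels before the file closes.
        return Image.open(f).convert("RGB")


# For S3 this would be fs = s3fs.S3FileSystem(...); an in-memory
# filesystem stands in here so the sketch runs without credentials.
fs = fsspec.filesystem("memory")

# Same shape as the S3 loader: pass this as FileDataset(..., loader=load).
load = functools.partial(default_loader, fs=fs)
```

Binding fs with functools.partial keeps FileDataset unaware of where the bytes live; only the loader changes between local and remote storage.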

Things seem to be working well after that. PyTorch models seem to (de)serialize much better than TensorFlow's did the last time I tried.


mrocklin commented on August 14, 2024

Do we need to use the Torch Dataset API here?

because it's not 100% straightforward how to get the data loaded onto workers

I guess my hope is that, for image data at least, we could just pass around NumPy arrays. So we might create dask.delayed objects using skimage.io.imread or something similar (perhaps like https://blog.dask.org/2019/06/20/load-image-data, but stopping before the Dask Array step).
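A minimal sketch of that idea, assuming scikit-image and Dask are installed: one dask.delayed task per file reads a NumPy array with skimage.io.imread, and a second task applies a model to it. The helper names lazy_load_images and lazy_predictions are hypothetical, and model stands for any callable that accepts a NumPy array (a PyTorch model would still need its own array-to-tensor conversion).

```python
import glob

import dask
from skimage.io import imread


def lazy_load_images(pattern):
    """Return one dask.delayed object per file matching `pattern`.

    Nothing is read yet; each object becomes a NumPy array
    (H x W x C for color images) when it is computed.
    """
    files = sorted(glob.glob(pattern))
    return [dask.delayed(imread)(path) for path in files]


def lazy_predictions(model, pattern):
    """Chain a prediction task onto each image load; still nothing runs."""
    predict = dask.delayed(model)
    return [predict(img) for img in lazy_load_images(pattern)]
```

Calling dask.compute(*lazy_predictions(model, "hymenoptera_data/val/*/*.jpg")) would then run the loads and predictions in parallel. Because each image is read inside a task, the pixel data is produced on the workers and generally doesn't need to pass through the client.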


mrocklin commented on August 14, 2024

Also, if you haven't seen it, this GTC 2019 talk on using Dask and V100s for fast distributed batch scoring of computer vision workloads is nice: https://developer.download.nvidia.com/video/gputechconf/gtc/2019/video/S9198/s9198-dask-and-v100s-for-fast-distributed-batch-scoring-of-computer-vision-workloads.mp4


TomAugspurger commented on August 14, 2024


mrocklin commented on August 14, 2024


stsievert commented on August 14, 2024

I'm also curious about taking Dask Arrays as inputs and returning model predictions as outputs. I think PyTorch Datasets will need to play an intermediate role; at least, that's what skorch uses when I trace net.py's Net.predict down to net.py#L1150.
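For the Dask Array side, one option that needs no Dataset in the middle is Dask Array's map_blocks: run a NumPy-in/NumPy-out prediction function on each chunk of rows. This is roughly what Dask-ML's ParallelPostFit automates for predict. The sketch assumes a 2-D array with samples along axis 0; predict_dask_array and predict_fn are hypothetical names, and for a skorch net predict_fn might be net.predict (possibly after casting blocks to float32).

```python
import dask.array as da
import numpy as np


def predict_dask_array(X, predict_fn, dtype):
    """Apply a NumPy-in/NumPy-out `predict_fn` to a 2-D Dask Array.

    Rows are samples. Merging chunks along axis 1 guarantees every block
    holds whole samples; each (n, n_features) block maps to shape (n,).
    """
    X = X.rechunk({1: -1})
    return X.map_blocks(
        predict_fn,
        drop_axis=1,  # the feature axis is absent from the output
        # Supplying meta stops Dask from calling predict_fn on an empty
        # array just to infer the output type.
        meta=np.array((), dtype=dtype),
    )
```

The result is itself a lazy Dask Array, so predictions stay chunked across workers until .compute() or a write to storage.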

Distributed training would also be interesting, of course, but my guess is that that's more of an open problem.

It's also mentioned in dask/distributed#2581.


AlbertDeFusco commented on August 14, 2024

Skorch looks interesting to me. Can the wrapper be used after loading the model from disk where the wrapper was not used?

I've tried applying Dask-ML's ParallelPostFit wrapper to a pre-trained model, and I remember having to do a few manual steps before running predictions. I need to dig up that code.


stsievert commented on August 14, 2024

Can the wrapper be used after loading the model from disk where the wrapper was not used?

Yup. After initialize(), the underlying PyTorch module is available as the .module_ attribute, so it's simple:

import torch
from skorch import NeuralNetClassifier

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ...

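model = Net()
# ... train the model ...

# Save the trained parameters with plain PyTorch
torch.save(model.state_dict(), "trained_model.pt")

# Later (not necessarily in the training session), wrap the same
# architecture with skorch. Passing the class lets skorch instantiate
# it; initialize() creates sk_net.module_ without any training.
sk_net = NeuralNetClassifier(Net)
sk_net.initialize()

# Load the parameters saved with PyTorch into the wrapped module
sk_net.module_.load_state_dict(torch.load("trained_model.pt"))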
