
Comments (10)

redhat12345 commented on August 22, 2024

Here is the code for loading the pretrained model:

```python
import torch.utils.model_zoo as model_zoo


# load the ImageNet-pretrained AlexNet weights into `model`
def load_pretrained(model):
    url = 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth'
    pretrained_dict = model_zoo.load_url(url)
    model_dict = model.state_dict()

    # keep only the pretrained entries whose keys also exist in this model;
    # uncomment the two `del` lines to also drop the last FC layer (classifier.6)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # del pretrained_dict['classifier.6.bias']
    # del pretrained_dict['classifier.6.weight']

    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
```
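The dictionary comprehension above filters by key name only. If a layer exists in both models under the same name but with a different shape (for example a 1000-way ImageNet head versus a 10-way digit head), `load_state_dict` still raises a size-mismatch error, which is what the commented-out `del` lines guard against. A minimal sketch that filters on name and shape together; `load_matching_weights` is a hypothetical helper, and the two small `nn.Sequential` models merely stand in for AlexNet and the target network:

```python
import torch.nn as nn


def load_matching_weights(model, pretrained_dict):
    """Copy only the entries whose name AND shape match the target model."""
    model_dict = model.state_dict()
    matched = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return sorted(matched)  # names of the tensors that were copied


# Toy stand-ins: the "pretrained" net has a 1000-way head, the target a 10-way head.
src = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1000))
dst = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 10))
loaded = load_matching_weights(dst, src.state_dict())
print(loaded)  # ['0.bias', '0.weight']
```

The mismatched head (`2.weight`, `2.bias`) keeps its own random initialization, so no per-layer `del` is needed.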

from deepcoral.

deep0learning commented on August 22, 2024

@redhat12345 Thanks for your reply. Can anyone please help me train the model without the pretrained model?


SSARCandy commented on August 22, 2024

You can remove the `load_pretrained()` call in `__main__` and the CORAL loss in `train()`.


deep0learning commented on August 22, 2024

@SSARCandy
Thank you so much for your reply. What I want is to use the CORAL loss without the pretrained model. How can I do that? Thanks in advance.
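CORAL does not depend on pretrained weights: it only compares the second-order statistics (covariances) of source and target features, so it can be added to a randomly initialized network. A minimal sketch of the loss as defined in the Deep CORAL paper, the squared Frobenius distance between the two feature covariance matrices divided by 4d², combined with the source classification loss through a weight. The function names and the `lambda_coral` parameter are illustrative, not the repository's code:

```python
import torch
import torch.nn.functional as F


def covariance(features):
    """Unbiased covariance of a (batch, d) feature matrix."""
    n = features.size(0)
    centered = features - features.mean(dim=0, keepdim=True)
    return centered.t() @ centered / (n - 1)


def coral_loss(source, target):
    """Squared Frobenius distance between source/target covariances, scaled by 1/(4 d^2)."""
    d = source.size(1)
    diff = covariance(source) - covariance(target)
    return (diff * diff).sum() / (4 * d * d)


def total_loss(source_logits, source_labels, source_feat, target_feat, lambda_coral=1.0):
    """Source cross-entropy plus weighted CORAL; lambda_coral=0 gives a source-only baseline."""
    classification = F.cross_entropy(source_logits, source_labels)
    return classification + lambda_coral * coral_loss(source_feat, target_feat)


torch.manual_seed(0)
src = torch.randn(32, 84)
tgt = 2.0 * torch.randn(32, 84) + 1.0  # different scale and offset
print(coral_loss(src, src).item())      # 0.0
print(coral_loss(src, tgt).item() > 0)  # True
```

With `lambda_coral=0.0` the objective reduces to plain source cross-entropy, which is equivalent to deleting the CORAL term from `train()`. Which layer's outputs to align is a design choice; the Deep CORAL paper aligns the outputs of the last fully connected layer.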


SSARCandy commented on August 22, 2024

Oh. Just remove the `load_pretrained()` call in `__main__`; then it will not load the ImageNet-pretrained model.


deep0learning commented on August 22, 2024

I have removed `load_pretrained()` in main, but I got this error:

```
Traceback (most recent call last):
  File "main.py", line 160, in <module>
    load_pretrained(model.sharedNet)
NameError: name 'load_pretrained' is not defined
```
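The `NameError` means the definition of `load_pretrained` was deleted while its call at line 160 of `main.py` was left in place; only the call needs to go. One way to keep both options available is to guard the call behind a command-line flag. A sketch, assuming a hypothetical `--pretrained` flag (it is not part of the repository's `main.py`):

```python
import argparse


def parse_args(argv=None):
    """Parse command-line options; argv=None reads sys.argv."""
    parser = argparse.ArgumentParser(description="Deep CORAL training (sketch)")
    # Hypothetical flag, off by default, so sharedNet keeps its random initialization.
    parser.add_argument("--pretrained", action="store_true",
                        help="initialize sharedNet with ImageNet AlexNet weights")
    return parser.parse_args(argv)


args = parse_args([])                           # no flag given
print(args.pretrained)                          # False
print(parse_args(["--pretrained"]).pretrained)  # True
```

In `main.py` the call would then read `if args.pretrained: load_pretrained(model.sharedNet)`, with the definition of `load_pretrained` left in place.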


SSARCandy commented on August 22, 2024

All you have to do is remove lines 143-146.


deep0learning commented on August 22, 2024

@SSARCandy

Thank you so much. I want to use LeNet for SVHN to MNIST adaptation, so I created a `LeNet` class, but I got an error. Can you please tell me how to do that? Thanks in advance.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    # SVHN is RGB and MNIST is single-channel: convert both to 3-channel grayscale
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

# SVHN (source) to MNIST (target)
train_dataset = datasets.SVHN('SVHN', download=True, transform=transform, split='train')
valid_dataset = datasets.MNIST('MNIST', download=True, transform=transform, train=True)

# BATCH_SIZE is defined in main.py
source_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE[0], shuffle=True)
target_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE[1], shuffle=True)


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        # fc3 is left out: DeepCORAL's source_fc/target_fc act as the classifier
        # self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        # return the 84-d features; calling self.fc3 here would raise AttributeError
        return out
```

I have also changed `DeepCORAL`:

```python
class DeepCORAL(nn.Module):
    # for SVHN -> MNIST there are 10 digit classes, so construct with num_classes=10
    def __init__(self, num_classes=1000):
        super(DeepCORAL, self).__init__()
        self.sharedNet = LeNet()
        # LeNet returns 84-d features, so both heads take 84 inputs
        self.source_fc = nn.Linear(84, num_classes)
        self.target_fc = nn.Linear(84, num_classes)

        # initialize according to CORAL paper experiment
        self.source_fc.weight.data.normal_(0, 0.005)
        self.target_fc.weight.data.normal_(0, 0.005)
```


redhat12345 commented on August 22, 2024

@deep0learning

Change the input size: it should be 32×32.

```python
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])
```
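Why 32×32: with two unpadded 5×5 convolutions and two 2×2 max-pools, a 32×32 input shrinks to 16 channels of 5×5, i.e. the 16 * 5 * 5 = 400 features that `fc1` expects. A 28×28 input ends at 4×4, only 256 features, so the `fc1` matrix multiplication fails with a shape mismatch. The arithmetic as a small pure-Python sketch (helper names are illustrative, and it assumes the LeNet layers shown above):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a square convolution (floor division, as in PyTorch)."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet_flatten_features(input_size, channels=16):
    """Number of features reaching fc1 for a square input of side input_size."""
    s = conv_out(input_size, 5)  # conv1, 5x5
    s = s // 2                   # max_pool2d(out, 2)
    s = conv_out(s, 5)           # conv2, 5x5
    s = s // 2                   # max_pool2d(out, 2)
    return channels * s * s


print(lenet_flatten_features(28))  # 256
print(lenet_flatten_features(32))  # 400
```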


deep0learning commented on August 22, 2024

@redhat12345

Thank you so much. It's working, but the accuracy is poor.

