

dta.pytorch's Issues

performance issue

Hello, thank you for the paper and code.

Could you possibly share plots of target_ce_loss and target_accuracy over the 20 epochs of training?

I was training DTA on the VisDA-2017 dataset with only the knife, bicycle, and skateboard classes, using the following command:

!python main.py --config_path ./configs/resnet101_dta_vat.json --batch_size 50 --experiment_dir logs_DTA_visda_subset --use_vat True

The only things I changed from your paper are that I am using only 3 VisDA classes instead of 12, and a batch size of 50 instead of 128 (due to GPU memory limits).

The result is that source_ce_loss continues to decrease and source_accuracy reaches 99% by epoch 2. However, my target_ce_loss continues to increase, and target_accuracy stays around 4~6%.

About source consistency loss...

During training:

        # Source CE Loss
        source_features1, source_features2 = self.feature_extractor(source_inputs)
        source_logits1, source_logits2 = self.classifier(source_features1), self.classifier(source_features2)

        # Source pi model
        ce_loss = self.class_criterion(source_logits1, source_labels)
        source_consistency_loss = 2 * self.source_cnn_consistency_weight * self.source_consistency_criterion(
            source_logits2, source_logits1)
        source_loss = ce_loss + source_consistency_loss
        source_loss.backward()

Here, aren't source_logits1 and source_logits2 exactly the same? If so, self.source_consistency_criterion(source_logits2, source_logits1) would seem to have no effect...
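(For what it's worth, one possible answer: if the feature extractor applies dropout independently on each of the two forward passes, the two sets of logits share weights but see different dropout masks, so the consistency term is not trivially zero. A minimal pure-Python sketch of that effect; the dropout rate and feature values below are made up, not taken from the repository:)

```python
import random

def dropout_mask(n, drop_rate, rng):
    """Inverted-dropout mask: each unit is kept with prob (1 - drop_rate)
    and scaled so the expected activation is unchanged."""
    scale = 1.0 / (1.0 - drop_rate)
    return [scale if rng.random() >= drop_rate else 0.0 for _ in range(n)]

def forward(features, drop_rate, rng):
    """Apply an independent dropout mask to a shared feature vector."""
    mask = dropout_mask(len(features), drop_rate, rng)
    return [f * m for f, m in zip(features, mask)]

rng = random.Random(0)
features = [0.3, -1.2, 0.8, 2.1, -0.5, 1.7, -0.9, 0.4]  # shared "clean" features

logits1 = forward(features, 0.5, rng)
logits2 = forward(features, 0.5, rng)

# Same weights, independent masks: the two "views" differ, so a
# consistency loss between them carries a real training signal.
print(logits1 == logits2)  # → False
```

If the two logits really were computed deterministically from the same input, the question above would stand: the consistency loss would always be zero.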

Problem when saving and loading models

Hi,

First things first, thank you for writing and sharing such well-structured code. I just wanted to bring one issue to your attention: the code as it stands has a problem when one wants to save a model for later use. You override the load_state_dict() method (to properly recover weights from ImageNet pretraining, I guess), but then use the same overridden function to recover weights from a given checkpoint:

    if args.classifier_ckpt_path:
        print("Load class classifier from {}".format(args.classifier_ckpt_path))
        ckpt = torch.load(args.classifier_ckpt_path)
        class_classifier.load_state_dict(ckpt['classifier_state_dict'])

Due to the filtering conditions in your load_state_dict() method, this does not properly recover the weights of the previous network. When loading from a checkpoint, you should call the original PyTorch load_state_dict() method directly. (I bring this up because it does not raise any error and is not at all obvious when looking at the training statistics of the restored network.) Thanks again for your work :)
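(To illustrate the pitfall and the fix: the hypothetical FilteringModel below stands in for the repository's model; in actual PyTorch code the equivalent fix is calling torch.nn.Module.load_state_dict(model, ckpt['classifier_state_dict']) directly. The point is that an overridden, filtering loader silently drops keys, while invoking the base-class implementation restores everything:)

```python
class Base:
    """Stand-in for the framework's full state loader."""
    def load_state_dict(self, state):
        self.state = dict(state)

class FilteringModel(Base):
    # Hypothetical stand-in for the repo's overridden load_state_dict,
    # which keeps only keys matching its filtering conditions.
    def load_state_dict(self, state):
        filtered = {k: v for k, v in state.items() if k.startswith("features.")}
        super().load_state_dict(filtered)

ckpt = {"features.conv1": 1, "classifier.fc": 2}

m = FilteringModel()
m.load_state_dict(ckpt)        # overridden: silently drops classifier weights
assert "classifier.fc" not in m.state

Base.load_state_dict(m, ckpt)  # call the base method directly: full restore
assert "classifier.fc" in m.state
```

No exception is raised in either case, which is exactly why the bug is easy to miss.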

Advice for hyper-parameter tuning

Hi,

Thanks for this repository, it is very well-structured and really easy to follow. I have a question about general advice for hyper-parameter tuning:

I've modified the code to run a 1D ResNet-34 on time-series data. Right now the model performs very well on the source distribution, reaching a maximum accuracy of 95.6%, but reaches only a maximum of 4.3% on the target distribution. I've noticed that as it fits the source distribution, accuracy on the target distribution tends to decrease.

Would you have any advice for tuning the hyperparameters? It seems to be overfitting the source distribution without learning any shared features.

Thanks again!

Test images used in the validation set?

Hi,

Thanks for such well-structured code. I apologize in advance if my question is too naive or silly, but I am just getting started in the field of domain adaptation.

Looking at the dataloaders_factory function in the datasets/__init__.py file (lines 21 to 41), it is clear that a subset (of size batch_size * 5) of the unlabeled target dataset is used as the validation set. This seems to violate machine learning 101: never mix the test set with the train or validation sets (assuming the target dataset is the test set). Am I missing something here?
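(For context: in fully unsupervised adaptation there are no target labels to validate on, which is why some implementations peek at target data for model selection. If one wants to avoid touching target/test images entirely, one label-free alternative, sketched below and not what this repository does, is to hold out part of the labeled source data as the validation split:)

```python
import random

def split_indices(n, val_fraction=0.1, seed=0):
    """Carve a validation subset out of the labeled *source* data only,
    leaving the target/test set untouched for final evaluation."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)   # fixed seed for a reproducible split
    n_val = int(n * val_fraction)
    return idx[n_val:], idx[:n_val]    # (train_idx, val_idx)

train_idx, val_idx = split_indices(100, val_fraction=0.2)
assert len(train_idx) == 80 and len(val_idx) == 20
assert set(train_idx).isdisjoint(val_idx)  # no leakage between the splits
```

The trade-off is that source validation accuracy does not track target accuracy under domain shift, so this is a principled but weaker selection signal.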

paper idea

SAdD is used to maximize the divergence between a model's prediction and the ground-truth label, while the stated goal is to enforce the cluster assumption on target data by minimizing the divergence between predictions.
Isn't there a conflict between these two objectives?

About the parameters in Supplementary Material A

Hello professor. Thank you for contributing such great work.
I have a small problem :)
In the supplementary material of DTA, I saw the SVHN→MNIST experiment uses a backbone of 9 Conv + 1 FC. Could you kindly tell me the hyperparameters of those 9 Conv and 1 FC layers, or are they written somewhere in your code?
Thanks for your help in advance!

About supplementary material A for DTA

Thank you, professors, for such great work published at ICCV :)

In supplementary material A, you used the 9 Conv + 1 FC backbone for SVHN→MNIST.

What puzzles me is what the feature extractor and classifier should be when using this backbone. I tried using the 9 Conv layers as the feature extractor and the 1 FC layer as the classifier, but it performed poorly.
Could you kindly tell me how I should choose the feature extractor and classifier?

Thank you in advance.
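(Since the exact 9 Conv + 1 FC hyperparameters are the open question in the two issues above, the following is only a guess at a typical layout for 32x32 digit benchmarks: three blocks of three 3x3 convolutions, each block followed by 2x2 max pooling, then a single fully connected classifier. The helper just checks what spatial size the FC layer would see under that assumption; it is not the paper's actual configuration:)

```python
def feature_map_size(input_size=32, blocks=3, convs_per_block=3):
    """Spatial size after a hypothetical 9-conv stack: `blocks` groups of
    `convs_per_block` 3x3 (stride 1, pad 1) convs, each group ending
    in 2x2 max pooling."""
    size = input_size
    for _ in range(blocks):
        for _ in range(convs_per_block):
            size = (size + 2 * 1 - 3) // 1 + 1  # 3x3 conv, pad 1: size unchanged
        size //= 2  # 2x2 max pooling halves the spatial size
    return size

print(feature_map_size(32))  # → 4: the single FC classifier would see C x 4 x 4
```

Under this assumption, the natural split is the 9-conv stack as the feature extractor and the 1 FC layer as the classifier; if that split performs poorly, the difference is more likely in the unstated conv widths, normalization, and dropout placement than in the split itself.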
