rcig's Issues

PyTorch Implementation

Do you plan to release a PyTorch implementation of this code?

I am trying to replicate this as closely as possible in PyTorch, reusing your code where I can and writing my own replacements for the rest; however, this only gets ~27% accuracy with the conv architecture. For the loss function, are you just using MSE? I see the l2 code, but since l2 is set to 0 I cannot find where it actually makes a difference.
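For concreteness, this is the training loss I am assuming on my side: plain MSE against the soft one-hot targets, plus an optional L2 penalty on the weights. The helper name and the l2 argument here are mine, mirroring what I think your l2 code does, not something taken from your repository:

import torch.nn.functional as F

def mse_loss_with_l2(model, outputs, targets, l2=0.0):
    # plain MSE against the (soft) one-hot targets
    loss = F.mse_loss(outputs, targets)
    # optional weight penalty; with l2 = 0.0 this term vanishes, which is why
    # I could not see it making any difference
    if l2 > 0.0:
        loss = loss + l2 * sum((p ** 2).sum() for p in model.parameters())
    return loss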

All I actually need is a train(my_model) routine and a working PyTorch test_loader that evaluates it, but there is a lot of bloat in this code.

Here is my implementation. I have no clue what could be causing this. I use your data, so ZCA should already have been applied. I use the Adam optimizer instead, but that shouldn't drop the accuracy by this much. Everything outside of this code is your JAX code.

import os, time

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import absl.flags
from absl import logging
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from flax.training import checkpoints
# get_config, get_dataset, and configure_dataloader are the helpers from your (RCIG) codebase

class CoresetDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # HWC -> CHW for PyTorch
        image = self.images[idx].transpose((2, 0, 1))
        label = self.labels[idx]
        return torch.tensor(image, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)



class TensorFlowToTorchDataset(Dataset):
    def __init__(self, tf_dataset):
        images = []
        labels = []
        
        # materialize the TensorFlow dataset into in-memory numpy arrays
        for image, label in tf_dataset.as_numpy_iterator():
            images.append(image)
            labels.append(label)

        self.images = np.concatenate(images, axis=0)
        self.labels = np.concatenate(labels, axis=0)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        # HWC -> CHW for PyTorch
        return torch.tensor(image.transpose((2, 0, 1)), dtype=torch.float32), torch.tensor(label, dtype=torch.float32)


def torch_eval(dataset_name='cifar10', data_path=None, zca_path=None, train_log=None, train_img=None,
               width=128, depth=3, normalization='identity', eval_lr=0.0001, random_seed=0,
               message='eval_log', output_dir=None, max_cycles=1000, config_path=None,
               checkpoint_path=None, save_name='eval_result', log_dir=None, eval_arch='resnet',
               models_to_test=5):
    # --------------------------------------
    # Setup
    # --------------------------------------
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

    if output_dir is None:
        output_dir = os.path.dirname(checkpoint_path)

    if log_dir is None:
        log_dir = output_dir

    logging.use_absl_handler()

    logging.get_absl_handler().use_absl_log_file('{}, {}'.format(int(time.time()), message), './{}/'.format(log_dir))
    absl.flags.FLAGS.mark_as_parsed() 
    logging.set_verbosity('info')
    
    logging.info('\n\n\n{}\n\n\n'.format(message))
    
    config = get_config()
    #config.kernel.batch_size = 256
    config.random_seed = random_seed
    config.train_log = train_log if train_log else 'train_log'
    config.train_img = train_img if train_img else 'train_img'


    config.dataset.data_path = data_path if data_path else 'data/tensorflow_datasets'
    config.dataset.zca_path = zca_path if zca_path else 'data/zca'
    config.dataset.name = dataset_name

    (ds_train, ds_test), preprocess_op, rev_preprocess_op, proto_scale = get_dataset(config.dataset)
    
    # soft one-hot targets: 1 - 1/C for the true class, -1/C for every other class
    y_transform = lambda y: tf.one_hot(y, config.dataset.num_classes, on_value=1 - 1 / config.dataset.num_classes,
                                       off_value=-1 / config.dataset.num_classes)
    ds_train = configure_dataloader(ds_train, batch_size=config.kernel.batch_size, y_transform=y_transform,
                                        train=True, shuffle=True)
    ds_test = configure_dataloader(ds_test, batch_size=config.kernel.eval_batch_size, y_transform=y_transform,
                                   train=False, shuffle=False)
    
    num_classes = config.dataset.num_classes

    if config.dataset.img_shape[0] in [28, 32]:
        depth = 3
    elif config.dataset.img_shape[0] == 64:
        depth = 4
    elif config.dataset.img_shape[0] == 128:
        depth = 5
    else:
        raise Exception('Invalid resolution for the dataset')


    loaded_checkpoint = checkpoints.restore_checkpoint(f'./{checkpoint_path}', None)
    coreset_images = loaded_checkpoint['ema_average']['x_proto']
    coreset_labels = loaded_checkpoint['ema_average']['y_proto']
    coreset_dataset = CoresetDataset(coreset_images, coreset_labels)
    train_loader = DataLoader(coreset_dataset, batch_size=config.kernel.batch_size, shuffle=True)
    
    # Wrap the whole TensorFlow test set in a PyTorch Dataset, then hand it to a DataLoader
    torch_ds_test = TensorFlowToTorchDataset(ds_test)
    test_loader = DataLoader(torch_ds_test, batch_size=config.kernel.eval_batch_size, shuffle=False)

    
    if eval_arch == 'conv':
        model = ConvNet(3, 10, 64, 3, 'relu', 'batchnorm', 'maxpooling').to(device)
        
        #Conv(use_softplus = False, beta = 20., num_classes = num_classes, width = width, depth = depth, normalization = normalization)
    elif eval_arch == 'resnet':
        model = resnet18()
        model.fc = nn.Linear(512,10)
        model = model.to(device)

    # Replace the optimizer initialization with PyTorch's optimizer
    optimizer = Adam(model.parameters(), lr=eval_lr)

    # Replace the learning rate scheduler with PyTorch's scheduler
    num_online_eval_updates = 1000 if coreset_images.shape[0] == 10 else 2000
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_online_eval_updates, eta_min=0.01)
    criterion = nn.MSELoss() #torch.nn.BCEWithLogitsLoss()

    # Training loop should be modified to work with PyTorch
    for epoch in range(max_cycles):
        # Training
        model.train()
        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            #loss, _ = torch_get_training_loss_l2(model, images, labels, l2=0.0)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Update the learning rate
        lr_scheduler.step()
        torch.cuda.empty_cache()

    # Evaluation
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for i, (images, labels) in enumerate(test_loader): 
            images, labels = images.to(device), labels.to(device)
            test_output = model(images)

            
            # targets are soft one-hot vectors, so compare the predicted class index with the target's argmax
            pred_y = torch.max(test_output, 1)[1].squeeze()
            true_y = labels
            correct += (pred_y == true_y.argmax(dim=1)).sum().item()
            total += labels.size(0)
        accuracy = correct / total
        print('Accuracy: ', accuracy)

    return model
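For context, this is roughly how I invoke the routine above; the checkpoint path is just the one from the README eval command and should be treated as a placeholder:

model = torch_eval(dataset_name='cifar10',
                   checkpoint_path='distilled_images_final/0/cifar10_10/checkpoint_10000',
                   eval_arch='conv')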

And this is my conv class:

class ConvNet(nn.Module):
    def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)):
        super(ConvNet, self).__init__()

        self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
        num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2]
        self.classifier = nn.Linear(num_feat, num_classes)

    def forward(self, x):
        # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device())
        out = self.features(x)
        out = out.reshape(out.size(0), -1) #was view
        out = self.classifier(out)
        return out

    def _get_activation(self, net_act):
        if net_act == 'sigmoid':
            return nn.Sigmoid()
        elif net_act == 'relu':
            return nn.ReLU(inplace=True)
        elif net_act == 'leakyrelu':
            return nn.LeakyReLU(negative_slope=0.01)
        else:
            exit('unknown activation function: %s'%net_act)

    def _get_pooling(self, net_pooling):
        if net_pooling == 'maxpooling':
            return nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'avgpooling':
            return nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'none':
            return None
        else:
            exit('unknown net_pooling: %s'%net_pooling)

    def _get_normlayer(self, net_norm, shape_feat):
        # shape_feat = (c*h*w)
        if net_norm == 'batchnorm':
            return nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == 'layernorm':
            return nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == 'instancenorm':
            return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == 'groupnorm':
            return nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == 'none':
            return None
        else:
            exit('unknown net_norm: %s'%net_norm)

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
        layers = []
        in_channels = channel
        if im_size[0] == 28:
            im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]
        for d in range(net_depth):
            layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
            shape_feat[0] = net_width
            if net_norm != 'none':
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self._get_activation(net_act)]
            in_channels = net_width
            if net_pooling != 'none':
                layers += [self._get_pooling(net_pooling)]
                shape_feat[1] //= 2
                shape_feat[2] //= 2


        return nn.Sequential(*layers), shape_feat
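As a quick sanity check that the ConvNet above is wired the way I intend (this is just my own smoke test, not code from the repository):

model = ConvNet(3, 10, 64, 3, 'relu', 'batchnorm', 'maxpooling')
dummy = torch.randn(2, 3, 32, 32)   # two CIFAR-10-sized images, NCHW
print(model(dummy).shape)           # expected: torch.Size([2, 10])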

Environment dependencies

Dear author,
Could you please provide a requirements.txt or another file listing the dependencies?
I tried to create the conda environment myself; the resulting dependencies are as follows:
# Name Version Build Channel _libgcc_mutex 0.1 main _openmp_mutex 5.1 1_gnu absl-py 2.1.0 pypi_0 pypi array-record 0.5.1 pypi_0 pypi astunparse 1.6.3 pypi_0 pypi blas 1.0 mkl bzip2 1.0.8 h5eee18b_6 ca-certificates 2024.3.11 h06a4308_0 certifi 2024.2.2 py39h06a4308_0 chardet 5.2.0 pypi_0 pypi charset-normalizer 3.3.2 pypi_0 pypi chex 0.1.86 pypi_0 pypi click 8.1.7 pypi_0 pypi clu 0.0.12 pypi_0 pypi contextlib2 21.6.0 pypi_0 pypi cuda 11.6.1 0 nvidia cuda-cccl 11.6.55 hf6102b2_0 nvidia cuda-command-line-tools 11.6.2 0 nvidia cuda-compiler 11.6.2 0 nvidia cuda-cudart 11.6.55 he381448_0 nvidia cuda-cudart-dev 11.6.55 h42ad0f4_0 nvidia cuda-cuobjdump 11.6.124 h2eeebcb_0 nvidia cuda-cupti 11.6.124 h86345e5_0 nvidia cuda-cuxxfilt 11.6.124 hecbf4f6_0 nvidia cuda-driver-dev 11.6.55 0 nvidia cuda-gdb 12.4.127 0 nvidia cuda-libraries 11.6.1 0 nvidia cuda-libraries-dev 11.6.1 0 nvidia cuda-memcheck 11.8.86 0 nvidia cuda-nsight 12.4.127 0 nvidia cuda-nsight-compute 12.4.1 0 nvidia cuda-nvcc 11.6.124 hbba6d2d_0 nvidia cuda-nvdisasm 12.4.127 0 nvidia cuda-nvml-dev 11.6.55 haa9ef22_0 nvidia cuda-nvprof 12.4.127 0 nvidia cuda-nvprune 11.6.124 he22ec0a_0 nvidia cuda-nvrtc 11.6.124 h020bade_0 nvidia cuda-nvrtc-dev 11.6.124 h249d397_0 nvidia cuda-nvtx 11.6.124 h0630a44_0 nvidia cuda-nvvp 12.4.127 0 nvidia cuda-runtime 11.6.1 0 nvidia cuda-samples 11.6.101 h8efea70_0 nvidia cuda-sanitizer-api 12.4.127 0 nvidia cuda-toolkit 11.6.1 0 nvidia cuda-tools 11.6.1 0 nvidia cuda-visual-tools 11.6.1 0 nvidia dm-tree 0.1.8 pypi_0 pypi einops 0.8.0 pypi_0 pypi etils 1.5.2 pypi_0 pypi ffmpeg 4.3 hf484d3e_0 pytorch fire 0.6.0 pypi_0 pypi flatbuffers 24.3.25 pypi_0 pypi flax 0.8.3 pypi_0 pypi freetype 2.12.1 h4a9f257_0 fsspec 2024.3.1 pypi_0 pypi gast 0.5.4 pypi_0 pypi gds-tools 1.9.1.3 0 nvidia gmp 6.2.1 h295c915_3 gnutls 3.6.15 he1e5248_0 google-pasta 0.2.0 pypi_0 pypi grpcio 1.63.0 pypi_0 pypi h5py 3.11.0 pypi_0 pypi idna 3.7 py39h06a4308_0 importlib-metadata 7.1.0 pypi_0 pypi importlib-resources 6.4.0 pypi_0 pypi intel-openmp 2023.1.0 hdb19cb5_46306 jax 0.4.26 pypi_0 pypi jaxlib 0.4.26+cuda12.cudnn89 pypi_0 pypi jpeg 9e h5eee18b_1 keras 3.3.3 pypi_0 pypi lame 3.100 h7b6447c_0 lcms2 2.12 h3be6417_0 ld_impl_linux-64 2.38 h1181459_1 lerc 3.0 h295c915_0 libclang 18.1.1 pypi_0 pypi libcublas 11.9.2.110 h5e84587_0 nvidia libcublas-dev 11.9.2.110 h5c901ab_0 nvidia libcufft 10.7.1.112 hf425ae0_0 nvidia libcufft-dev 10.7.1.112 ha5ce4c0_0 nvidia libcufile 1.9.1.3 0 nvidia libcufile-dev 1.9.1.3 0 nvidia libcurand 10.3.5.147 0 nvidia libcurand-dev 10.3.5.147 0 nvidia libcusolver 11.3.4.124 h33c3c4e_0 nvidia libcusparse 11.7.2.124 h7538f96_0 nvidia libcusparse-dev 11.7.2.124 hbbe9722_0 nvidia libdeflate 1.17 h5eee18b_1 libffi 3.4.4 h6a678d5_1 libgcc-ng 11.2.0 h1234567_1 libgomp 11.2.0 h1234567_1 libiconv 1.16 h5eee18b_3 libidn2 2.3.4 h5eee18b_0 libnpp 11.6.3.124 hd2722f0_0 nvidia libnpp-dev 11.6.3.124 h3c42840_0 nvidia libnvjpeg 11.6.2.124 hd473ad6_0 nvidia libnvjpeg-dev 11.6.2.124 hb5906b9_0 nvidia libpng 1.6.39 h5eee18b_0 libstdcxx-ng 11.2.0 h1234567_1 libtasn1 4.19.0 h5eee18b_0 libtiff 4.5.1 h6a678d5_0 libunistring 0.9.10 h27cfd23_0 libwebp-base 1.3.2 h5eee18b_0 lz4-c 1.9.4 h6a678d5_1 markdown 3.6 pypi_0 pypi markdown-it-py 3.0.0 pypi_0 pypi markupsafe 2.1.5 pypi_0 pypi mdurl 0.1.2 pypi_0 pypi mkl 2023.1.0 h213fc3f_46344 mkl-service 2.4.0 py39h5eee18b_1 mkl_fft 1.3.8 py39h5eee18b_0 mkl_random 1.2.4 py39hdb19cb5_0 ml-collections 0.1.1 pypi_0 pypi ml-dtypes 0.3.2 pypi_0 pypi msgpack 1.0.8 pypi_0 pypi namex 0.0.8 pypi_0 pypi ncurses 
6.4 h6a678d5_0 nest-asyncio 1.6.0 pypi_0 pypi nettle 3.7.3 hbbd107a_1 nsight-compute 2024.1.1.4 0 nvidia numpy 1.26.4 py39h5f9d8c6_0 numpy-base 1.26.4 py39hb5e798b_0 nvidia-cublas-cu12 12.3.4.1 pypi_0 pypi nvidia-cuda-cupti-cu12 12.3.101 pypi_0 pypi nvidia-cuda-nvcc-cu12 12.3.107 pypi_0 pypi nvidia-cuda-nvrtc-cu12 12.3.107 pypi_0 pypi nvidia-cuda-runtime-cu12 12.3.101 pypi_0 pypi nvidia-cudnn-cu12 8.9.7.29 pypi_0 pypi nvidia-cufft-cu12 11.0.12.1 pypi_0 pypi nvidia-curand-cu12 10.3.4.107 pypi_0 pypi nvidia-cusolver-cu12 11.5.4.101 pypi_0 pypi nvidia-cusparse-cu12 12.2.0.103 pypi_0 pypi nvidia-nccl-cu12 2.19.3 pypi_0 pypi nvidia-nvjitlink-cu12 12.3.101 pypi_0 pypi openh264 2.1.1 h4ff587b_0 openjpeg 2.4.0 h3ad879b_0 openssl 3.0.13 h7f8727e_1 opt-einsum 3.3.0 pypi_0 pypi optax 0.2.2 pypi_0 pypi optree 0.11.0 pypi_0 pypi orbax-checkpoint 0.5.10 pypi_0 pypi packaging 24.0 pypi_0 pypi pillow 10.3.0 py39h5eee18b_0 pip 23.3.1 py39h06a4308_0 promise 2.3 pypi_0 pypi protobuf 3.20.3 pypi_0 pypi psutil 5.9.8 pypi_0 pypi pygments 2.17.2 pypi_0 pypi python 3.9.19 h955ad1f_0 pytorch 1.13.1 py3.9_cuda11.6_cudnn8.3.2_0 pytorch pytorch-cuda 11.6 h867d48c_1 pytorch pytorch-mutex 1.0 cuda pytorch pyyaml 6.0.1 pypi_0 pypi readline 8.2 h5eee18b_0 requests 2.31.0 py39h06a4308_1 rich 13.7.1 pypi_0 pypi scipy 1.13.0 pypi_0 pypi setuptools 68.2.2 py39h06a4308_0 six 1.16.0 pypi_0 pypi sqlite 3.45.3 h5eee18b_0 tbb 2021.8.0 hdb19cb5_0 tensorboard 2.16.2 pypi_0 pypi tensorboard-data-server 0.7.2 pypi_0 pypi tensorflow 2.16.1 pypi_0 pypi tensorflow-datasets 4.9.3 pypi_0 pypi tensorflow-io-gcs-filesystem 0.37.0 pypi_0 pypi tensorflow-metadata 1.15.0 pypi_0 pypi tensorstore 0.1.58 pypi_0 pypi termcolor 2.4.0 pypi_0 pypi tk 8.6.12 h1ccaba5_0 toml 0.10.2 pypi_0 pypi toolz 0.12.1 pypi_0 pypi torchaudio 0.13.1 py39_cu116 pytorch torchvision 0.14.1 py39_cu116 pytorch tqdm 4.66.2 pypi_0 pypi typing-extensions 4.11.0 pypi_0 pypi typing_extensions 4.9.0 py39h06a4308_1 tzdata 2024a h04d1e81_0 urllib3 2.2.1 pypi_0 pypi werkzeug 3.0.2 pypi_0 pypi wheel 0.41.2 py39h06a4308_0 wrapt 1.16.0 pypi_0 pypi xz 5.4.6 h5eee18b_1 zipp 3.18.1 pypi_0 pypi zlib 1.2.13 h5eee18b_1 zstd 1.5.5 hc292b87_1
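For readability, here are the packages most relevant to this repository, pulled out of the listing above as a requirements.txt-style pin list (these are simply the versions from my environment, not an official dependency file):

absl-py==2.1.0
chex==0.1.86
clu==0.0.12
fire==0.6.0
flax==0.8.3
jax==0.4.26
jaxlib==0.4.26+cuda12.cudnn89
ml-collections==0.1.1
numpy==1.26.4
optax==0.2.2
tensorflow==2.16.1
tensorflow-datasets==4.9.3
torch==1.13.1        # installed via conda as pytorch 1.13.1
torchvision==0.14.1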
But when I run the code, it fails with the following JAX error:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/zlihm/yfp/RCIG/distill_dataset.py", line 282, in
fire.Fire(main)
File "/home/zlihm/anaconda3/envs/tf/lib/python3.9/site-packages/fire/core.py", line 143, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/home/zlihm/anaconda3/envs/tf/lib/python3.9/site-packages/fire/core.py", line 477, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/home/zlihm/anaconda3/envs/tf/lib/python3.9/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/home/zlihm/yfp/RCIG/distill_dataset.py", line 230, in main
inner_result, _ = algorithms.run_rcig(coreset_images, coreset_labels, model_for_train.init, model_for_train.apply, ds_train, alg_config, key, inner_learning_rate, hvp_learning_rate, lr_tune = True)
File "/home/zlihm/yfp/RCIG/algorithms.py", line 87, in run_rcig
new_train_state, key = get_new_train_state(key, alg_config.pool_learning_rate, alg_config.model_depth, alg_config.has_bn, alg_config.linearize, net_forward_apply, net_forward_init, coreset_images_init.shape, naive_loss = alg_config.naive_loss)
File "/home/zlihm/yfp/RCIG/algorithms.py", line 50, in get_new_train_state
new_params = new_params.unfreeze()
AttributeError: 'dict' object has no attribute 'unfreeze'
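This looks like a Flax versioning issue: newer Flax releases return a plain dict from Module.init instead of a FrozenDict, so the .unfreeze() call in algorithms.py fails. A minimal compatibility guard, assuming new_params comes straight from the model's init function:

from flax.core import FrozenDict

# older Flax returns a FrozenDict that must be unfrozen; newer Flax already
# returns a plain (mutable) dict, so only unfreeze when needed
if isinstance(new_params, FrozenDict):
    new_params = new_params.unfreeze()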

Will you release the source code of RCIG soon?

Hi,

Thanks for your excellent work!

When will you release the source code of RCIG? I want to apply RCIG to other domains, such as time series, so I'm wondering whether you plan to release the code soon or only once your paper is accepted.

Non-conv architectures do not reach the performance reported in the paper

I am using the script from the README, but changed eval_arch to resnet and vgg. These do not reach the accuracy reported in the paper.

python3 eval.py --dataset_name cifar10 --checkpoint_path ./distilled_images_final/0/cifar10_10/checkpoint_10000 --config_path ./configs_final/depth_3.txt --random_seed 0 --eval_arch resnet

ResNet achieves 20% and VGG gets 40%. Even after changing --normalization to batch, performance is still low.

In the supplementary material you state that no data augmentation gives better results for cifar10; however, your configs for cifar10 use data augmentation. Why is this? Furthermore, what hyperparameters am I supposed to use for ResNet on cifar10? I am assuming it's still depth_3.txt.
