rcig's Issues
PyTorch Implementation
Do you plan to release a PyTorch implementation of this code?
I am trying to replicate this as closely as possible in PyTorch, using the same code for everything and writing my own replicas where needed; however, this only reaches ~27% accuracy with the conv architecture. For the loss function, are you just using MSE? I see the l2 code, but you set l2 to 0, and I cannot find anywhere that it makes a difference.
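A note on the loss question, offered as a hedged sketch rather than the repository's actual definition: `nn.MSELoss()` averages over every element (batch x classes), while a JAX-style MSE is often written as a sum over classes averaged over the batch. The two differ by a constant factor equal to the number of classes C. With plain SGD such a factor acts like a learning-rate change; Adam is largely insensitive to it, apart from its epsilon term and any L2 penalty added on top. With `l2 = 0` the penalty vanishes and only the MSE term remains. The function names below are hypothetical, and the sum-over-classes convention is an assumption to check against the repository's loss code.

```python
import numpy as np


def mse_mean_all(outputs, targets):
    # Same reduction as torch.nn.MSELoss(): mean over every element (batch x classes).
    return float(np.mean((outputs - targets) ** 2))


def mse_sum_classes(outputs, targets):
    # Assumed JAX-style convention: sum over classes, then mean over the batch.
    return float(np.mean(np.sum((outputs - targets) ** 2, axis=-1)))


def loss_with_l2(outputs, targets, params, l2=0.0):
    # Hypothetical helper: MSE plus an L2 penalty over a list of parameter arrays.
    penalty = sum(float(np.sum(p ** 2)) for p in params)
    return mse_sum_classes(outputs, targets) + 0.5 * l2 * penalty


rng = np.random.default_rng(0)
outputs = rng.normal(size=(8, 10))  # batch of 8, C = 10 classes
targets = rng.normal(size=(8, 10))
params = [rng.normal(size=(4, 4))]

# The two reductions differ by the number of classes (about 10.0 here).
print(mse_sum_classes(outputs, targets) / mse_mean_all(outputs, targets))
# With l2 = 0 the penalty term vanishes and only the MSE term remains.
print(loss_with_l2(outputs, targets, params, l2=0.0) == mse_sum_classes(outputs, targets))
```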
All I actually need is a simple train(my_model) function and a working PyTorch test_loader to evaluate it, but there is a lot of bloat in this code.
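For reference, the `y_transform` in the RCIG data pipeline (used in `torch_eval` below) encodes class `y` with `on_value = 1 - 1/C` and `off_value = -1/C`, which is a standard one-hot vector with 1/C subtracted from every entry. A minimal NumPy sketch of that encoding (the function name is hypothetical) confirms that `argmax`, which the PyTorch test loop uses to recover the true class, is unchanged by the shift:

```python
import numpy as np


def centered_one_hot(y, num_classes):
    # Equivalent to tf.one_hot(y, C, on_value=1 - 1/C, off_value=-1/C):
    # a standard one-hot vector with 1/C subtracted from every entry.
    one_hot = np.eye(num_classes, dtype=np.float32)[y]
    return one_hot - 1.0 / num_classes


labels = centered_one_hot(np.array([3, 0, 9]), num_classes=10)
print(labels[0])             # entry 3 is about 0.9, every other entry is about -0.1
print(labels.sum(axis=1))    # each row sums to (approximately) zero
print(labels.argmax(axis=1)) # -> [3 0 9]
```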
Here is my implementation. I have no idea what could be causing this. I use your data, so ZCA whitening should already be applied. I use the Adam optimizer instead, but that should not drop the accuracy by that much. Everything outside of this code is your JAX code.
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class CoresetDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Distilled images are stored as HWC; PyTorch convolutions expect CHW.
        image = self.images[idx].transpose((2, 0, 1))
        label = self.labels[idx]
        return torch.tensor(image, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)
class TensorFlowToTorchDataset(Dataset):
    def __init__(self, tf_dataset):
        # Materialize the (batched) tf.data pipeline into NumPy arrays.
        images = []
        labels = []
        for image, label in tf_dataset.as_numpy_iterator():
            images.append(image)
            labels.append(label)
        self.images = np.concatenate(images, axis=0)
        self.labels = np.concatenate(labels, axis=0)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        return torch.tensor(image.transpose((2, 0, 1)), dtype=torch.float32), torch.tensor(label, dtype=torch.float32)
import os
import time

import absl.flags
import tensorflow as tf
import torch.nn as nn
from absl import logging
from flax.training import checkpoints
from torch.optim import Adam
from torchvision.models import resnet18
# get_config, get_dataset and configure_dataloader come from the RCIG repository.


def torch_eval(dataset_name='cifar10', data_path=None, zca_path=None, train_log=None, train_img=None,
               width=128, depth=3, normalization='identity', eval_lr=0.0001, random_seed=0,
               message='eval_log', output_dir=None, max_cycles=1000, config_path=None,
               checkpoint_path=None, save_name='eval_result', log_dir=None, eval_arch='resnet',
               models_to_test=5):
    # --------------------------------------
    # Setup
    # --------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if output_dir is None:
        output_dir = os.path.dirname(checkpoint_path)
    if log_dir is None:
        log_dir = output_dir
    logging.use_absl_handler()
    logging.get_absl_handler().use_absl_log_file('{}, {}'.format(int(time.time()), message), './{}/'.format(log_dir))
    absl.flags.FLAGS.mark_as_parsed()
    logging.set_verbosity('info')
    logging.info('\n\n\n{}\n\n\n'.format(message))

    config = get_config()
    # config.kernel.batch_size = 256
    config.random_seed = random_seed
    config.train_log = train_log if train_log else 'train_log'
    config.train_img = train_img if train_img else 'train_img'
    config.dataset.data_path = data_path if data_path else 'data/tensorflow_datasets'
    config.dataset.zca_path = zca_path if zca_path else 'data/zca'
    config.dataset.name = dataset_name

    (ds_train, ds_test), preprocess_op, rev_preprocess_op, proto_scale = get_dataset(config.dataset)
    # Centered one-hot labels: 1 - 1/C for the true class, -1/C elsewhere.
    y_transform = lambda y: tf.one_hot(y, config.dataset.num_classes, on_value=1 - 1 / config.dataset.num_classes,
                                       off_value=-1 / config.dataset.num_classes)
    ds_train = configure_dataloader(ds_train, batch_size=config.kernel.batch_size, y_transform=y_transform,
                                    train=True, shuffle=True)
    ds_test = configure_dataloader(ds_test, batch_size=config.kernel.eval_batch_size, y_transform=y_transform,
                                   train=False, shuffle=False)
    num_classes = config.dataset.num_classes

    if config.dataset.img_shape[0] in [28, 32]:
        depth = 3
    elif config.dataset.img_shape[0] == 64:
        depth = 4
    elif config.dataset.img_shape[0] == 128:
        depth = 5
    else:
        raise Exception('Invalid resolution for the dataset')

    # Load the distilled coreset (images and learned labels) from the JAX checkpoint.
    loaded_checkpoint = checkpoints.restore_checkpoint(f'./{checkpoint_path}', None)
    coreset_images = loaded_checkpoint['ema_average']['x_proto']
    coreset_labels = loaded_checkpoint['ema_average']['y_proto']
    coreset_dataset = CoresetDataset(coreset_images, coreset_labels)
    train_loader = DataLoader(coreset_dataset, batch_size=config.kernel.batch_size, shuffle=True)

    # Materialize the whole TensorFlow test set and wrap it in a PyTorch DataLoader.
    torch_ds_test = TensorFlowToTorchDataset(ds_test)
    test_loader = DataLoader(torch_ds_test, batch_size=config.kernel.eval_batch_size, shuffle=False)

    if eval_arch == 'conv':
        model = ConvNet(3, 10, 64, 3, 'relu', 'batchnorm', 'maxpooling').to(device)
        # Conv(use_softplus=False, beta=20., num_classes=num_classes, width=width, depth=depth, normalization=normalization)
    elif eval_arch == 'resnet':
        model = resnet18()
        model.fc = nn.Linear(512, 10)
        model = model.to(device)

    # PyTorch optimizer and learning-rate scheduler in place of the optax ones.
    optimizer = Adam(model.parameters(), lr=eval_lr)
    num_online_eval_updates = 1000 if coreset_images.shape[0] == 10 else 2000
    # Note: eta_min (0.01) is larger than eval_lr (1e-4), so this schedule raises the LR over time.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_online_eval_updates, eta_min=0.01)
    criterion = nn.MSELoss()  # torch.nn.BCEWithLogitsLoss()

    for epoch in range(max_cycles):
        # Training
        model.train()
        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # loss, _ = torch_get_training_loss_l2(model, images, labels, l2=0.0)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Update the learning rate (once per epoch)
        lr_scheduler.step()
        torch.cuda.empty_cache()

    # Evaluation
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            pred_y = model(images).argmax(dim=1)
            # Test labels are centered one-hot vectors, so argmax recovers the class index.
            correct += (pred_y == labels.argmax(dim=1)).sum().item()
            total += labels.size(0)
    accuracy = correct / total
    print('Accuracy: ', accuracy)
    return model
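One detail in the loop above is easy to miss: `CosineAnnealingLR` moves the learning rate from the optimizer's initial value towards `eta_min`, so with `eval_lr=0.0001` and `eta_min=0.01` the rate increases over training instead of decaying. Because `lr_scheduler.step()` runs once per epoch and `max_cycles=1000`, a run with `T_max=2000` also stops halfway through the schedule. The pure-Python sketch below evaluates the closed-form cosine schedule (without restarts) documented for `torch.optim.lr_scheduler.CosineAnnealingLR` to make this visible; whether it explains the accuracy gap is an assumption, not something verified here.

```python
import math


def cosine_annealing_lr(step, base_lr, eta_min, t_max):
    # Closed form of the schedule used by torch.optim.lr_scheduler.CosineAnnealingLR
    # (no restarts): starts at base_lr at step 0 and reaches eta_min at step t_max.
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


base_lr, eta_min, t_max = 1e-4, 0.01, 2000  # values taken from torch_eval above
for step in (0, 500, 1000, 2000):
    print(step, cosine_annealing_lr(step, base_lr, eta_min, t_max))
```

A decaying schedule would use an `eta_min` below `eval_lr`, for example `eta_min=0`.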
And this is my conv class:
import torch.nn as nn


class ConvNet(nn.Module):
    def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size=(32, 32)):
        super(ConvNet, self).__init__()
        self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier = nn.Linear(num_feat, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = out.reshape(out.size(0), -1)  # flatten to (batch, features)
        out = self.classifier(out)
        return out

    def _get_activation(self, net_act):
        if net_act == 'sigmoid':
            return nn.Sigmoid()
        elif net_act == 'relu':
            return nn.ReLU(inplace=True)
        elif net_act == 'leakyrelu':
            return nn.LeakyReLU(negative_slope=0.01)
        else:
            raise ValueError('unknown activation function: %s' % net_act)

    def _get_pooling(self, net_pooling):
        if net_pooling == 'maxpooling':
            return nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'avgpooling':
            return nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'none':
            return None
        else:
            raise ValueError('unknown net_pooling: %s' % net_pooling)

    def _get_normlayer(self, net_norm, shape_feat):
        # shape_feat = (c, h, w)
        if net_norm == 'batchnorm':
            return nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == 'layernorm':
            return nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == 'instancenorm':
            return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == 'groupnorm':
            return nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == 'none':
            return None
        else:
            raise ValueError('unknown net_norm: %s' % net_norm)

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
        layers = []
        in_channels = channel
        if im_size[0] == 28:
            im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]
        for d in range(net_depth):
            # Padding 3 on the first layer turns 28x28 single-channel inputs into 32x32 feature maps.
            layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
            shape_feat[0] = net_width
            if net_norm != 'none':
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self._get_activation(net_act)]
            in_channels = net_width
            if net_pooling != 'none':
                layers += [self._get_pooling(net_pooling)]
                shape_feat[1] //= 2
                shape_feat[2] //= 2
        return nn.Sequential(*layers), shape_feat
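As a quick sanity check on the architecture, `_make_layers` sets the channel count to `net_width` in every block and, with pooling enabled, halves both spatial dimensions once per block, so for 32x32 inputs the classifier sees `net_width * (32 / 2**net_depth)**2` features. The snippet mirrors that bookkeeping in plain Python (the function name is hypothetical). Note that `torch_eval` builds `ConvNet(3, 10, 64, 3, ...)`, i.e. width 64, while its own `width` argument defaults to 128; the sketch shows how the classifier input grows with width.

```python
def convnet_feature_dim(net_width, net_depth, im_size=32, pooling=True):
    # Mirrors ConvNet._make_layers: every block sets channels to net_width
    # and, if pooling is enabled, halves both spatial dimensions.
    channels, height, width = 3, im_size, im_size
    for _ in range(net_depth):
        channels = net_width
        if pooling:
            height //= 2
            width //= 2
    return channels * height * width


print(convnet_feature_dim(64, 3))   # 64 * 4 * 4 = 1024
print(convnet_feature_dim(128, 3))  # 128 * 4 * 4 = 2048
```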
How can we evaluate the pre-distilled images in PyTorch?
I have not had any luck using the pre-distilled images with general, out-of-the-box PyTorch implementations of 128-wide conv models.
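A minimal sketch of the conversion step, assuming the checkpoint stores the distilled images and labels under `ema_average/x_proto` and `ema_average/y_proto` as in the evaluation code from the first issue, with images in NHWC layout (the `transpose((2, 0, 1))` there suggests this). The distilled images presumably live in the space produced by the repository's preprocessing (including ZCA), so test images would need the same preprocessing; a stock pipeline that only normalizes per channel would not match. The helper name is hypothetical, and arrays are returned as NumPy so `torch.from_numpy` can wrap them.

```python
import numpy as np


def coreset_to_nchw(x_proto, y_proto):
    # Hypothetical helper: convert distilled images from NHWC (JAX/Flax layout)
    # to contiguous NCHW float32 (PyTorch layout); labels stay float32 targets.
    x = np.asarray(x_proto, dtype=np.float32)
    if x.ndim != 4:
        raise ValueError(f"expected a 4-D NHWC array, got shape {x.shape}")
    x = np.ascontiguousarray(np.transpose(x, (0, 3, 1, 2)))
    y = np.asarray(y_proto, dtype=np.float32)
    return x, y


# Placeholder arrays standing in for x_proto / y_proto (10 images, 32x32 RGB, 10 classes).
x_nchw, y = coreset_to_nchw(np.zeros((10, 32, 32, 3)), np.zeros((10, 10)))
print(x_nchw.shape, y.shape)  # (10, 3, 32, 32) (10, 10)
```

`torch.from_numpy(x_nchw)` and `torch.from_numpy(y)` can then be passed to `torch.utils.data.TensorDataset` and a `DataLoader`.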
Environment dependencies
Dear author,
Could you please provide a requirements.txt or another file listing the dependencies?
I tried to create the conda environment myself, and the dependencies are as follows:
# Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
absl-py 2.1.0 pypi_0 pypi
array-record 0.5.1 pypi_0 pypi
astunparse 1.6.3 pypi_0 pypi
blas 1.0 mkl
bzip2 1.0.8 h5eee18b_6
ca-certificates 2024.3.11 h06a4308_0
certifi 2024.2.2 py39h06a4308_0
chardet 5.2.0 pypi_0 pypi
charset-normalizer 3.3.2 pypi_0 pypi
chex 0.1.86 pypi_0 pypi
click 8.1.7 pypi_0 pypi
clu 0.0.12 pypi_0 pypi
contextlib2 21.6.0 pypi_0 pypi
cuda 11.6.1 0 nvidia
cuda-cccl 11.6.55 hf6102b2_0 nvidia
cuda-command-line-tools 11.6.2 0 nvidia
cuda-compiler 11.6.2 0 nvidia
cuda-cudart 11.6.55 he381448_0 nvidia
cuda-cudart-dev 11.6.55 h42ad0f4_0 nvidia
cuda-cuobjdump 11.6.124 h2eeebcb_0 nvidia
cuda-cupti 11.6.124 h86345e5_0 nvidia
cuda-cuxxfilt 11.6.124 hecbf4f6_0 nvidia
cuda-driver-dev 11.6.55 0 nvidia
cuda-gdb 12.4.127 0 nvidia
cuda-libraries 11.6.1 0 nvidia
cuda-libraries-dev 11.6.1 0 nvidia
cuda-memcheck 11.8.86 0 nvidia
cuda-nsight 12.4.127 0 nvidia
cuda-nsight-compute 12.4.1 0 nvidia
cuda-nvcc 11.6.124 hbba6d2d_0 nvidia
cuda-nvdisasm 12.4.127 0 nvidia
cuda-nvml-dev 11.6.55 haa9ef22_0 nvidia
cuda-nvprof 12.4.127 0 nvidia
cuda-nvprune 11.6.124 he22ec0a_0 nvidia
cuda-nvrtc 11.6.124 h020bade_0 nvidia
cuda-nvrtc-dev 11.6.124 h249d397_0 nvidia
cuda-nvtx 11.6.124 h0630a44_0 nvidia
cuda-nvvp 12.4.127 0 nvidia
cuda-runtime 11.6.1 0 nvidia
cuda-samples 11.6.101 h8efea70_0 nvidia
cuda-sanitizer-api 12.4.127 0 nvidia
cuda-toolkit 11.6.1 0 nvidia
cuda-tools 11.6.1 0 nvidia
cuda-visual-tools 11.6.1 0 nvidia
dm-tree 0.1.8 pypi_0 pypi
einops 0.8.0 pypi_0 pypi
etils 1.5.2 pypi_0 pypi
ffmpeg 4.3 hf484d3e_0 pytorch
fire 0.6.0 pypi_0 pypi
flatbuffers 24.3.25 pypi_0 pypi
flax 0.8.3 pypi_0 pypi
freetype 2.12.1 h4a9f257_0
fsspec 2024.3.1 pypi_0 pypi
gast 0.5.4 pypi_0 pypi
gds-tools 1.9.1.3 0 nvidia
gmp 6.2.1 h295c915_3
gnutls 3.6.15 he1e5248_0
google-pasta 0.2.0 pypi_0 pypi
grpcio 1.63.0 pypi_0 pypi
h5py 3.11.0 pypi_0 pypi
idna 3.7 py39h06a4308_0
importlib-metadata 7.1.0 pypi_0 pypi
importlib-resources 6.4.0 pypi_0 pypi
intel-openmp 2023.1.0 hdb19cb5_46306
jax 0.4.26 pypi_0 pypi
jaxlib 0.4.26+cuda12.cudnn89 pypi_0 pypi
jpeg 9e h5eee18b_1
keras 3.3.3 pypi_0 pypi
lame 3.100 h7b6447c_0
lcms2 2.12 h3be6417_0
ld_impl_linux-64 2.38 h1181459_1
lerc 3.0 h295c915_0
libclang 18.1.1 pypi_0 pypi
libcublas 11.9.2.110 h5e84587_0 nvidia
libcublas-dev 11.9.2.110 h5c901ab_0 nvidia
libcufft 10.7.1.112 hf425ae0_0 nvidia
libcufft-dev 10.7.1.112 ha5ce4c0_0 nvidia
libcufile 1.9.1.3 0 nvidia
libcufile-dev 1.9.1.3 0 nvidia
libcurand 10.3.5.147 0 nvidia
libcurand-dev 10.3.5.147 0 nvidia
libcusolver 11.3.4.124 h33c3c4e_0 nvidia
libcusparse 11.7.2.124 h7538f96_0 nvidia
libcusparse-dev 11.7.2.124 hbbe9722_0 nvidia
libdeflate 1.17 h5eee18b_1
libffi 3.4.4 h6a678d5_1
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libiconv 1.16 h5eee18b_3
libidn2 2.3.4 h5eee18b_0
libnpp 11.6.3.124 hd2722f0_0 nvidia
libnpp-dev 11.6.3.124 h3c42840_0 nvidia
libnvjpeg 11.6.2.124 hd473ad6_0 nvidia
libnvjpeg-dev 11.6.2.124 hb5906b9_0 nvidia
libpng 1.6.39 h5eee18b_0
libstdcxx-ng 11.2.0 h1234567_1
libtasn1 4.19.0 h5eee18b_0
libtiff 4.5.1 h6a678d5_0
libunistring 0.9.10 h27cfd23_0
libwebp-base 1.3.2 h5eee18b_0
lz4-c 1.9.4 h6a678d5_1
markdown 3.6 pypi_0 pypi
markdown-it-py 3.0.0 pypi_0 pypi
markupsafe 2.1.5 pypi_0 pypi
mdurl 0.1.2 pypi_0 pypi
mkl 2023.1.0 h213fc3f_46344
mkl-service 2.4.0 py39h5eee18b_1
mkl_fft 1.3.8 py39h5eee18b_0
mkl_random 1.2.4 py39hdb19cb5_0
ml-collections 0.1.1 pypi_0 pypi
ml-dtypes 0.3.2 pypi_0 pypi
msgpack 1.0.8 pypi_0 pypi
namex 0.0.8 pypi_0 pypi
ncurses 6.4 h6a678d5_0
nest-asyncio 1.6.0 pypi_0 pypi
nettle 3.7.3 hbbd107a_1
nsight-compute 2024.1.1.4 0 nvidia
numpy 1.26.4 py39h5f9d8c6_0
numpy-base 1.26.4 py39hb5e798b_0
nvidia-cublas-cu12 12.3.4.1 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.3.101 pypi_0 pypi
nvidia-cuda-nvcc-cu12 12.3.107 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.3.107 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.3.101 pypi_0 pypi
nvidia-cudnn-cu12 8.9.7.29 pypi_0 pypi
nvidia-cufft-cu12 11.0.12.1 pypi_0 pypi
nvidia-curand-cu12 10.3.4.107 pypi_0 pypi
nvidia-cusolver-cu12 11.5.4.101 pypi_0 pypi
nvidia-cusparse-cu12 12.2.0.103 pypi_0 pypi
nvidia-nccl-cu12 2.19.3 pypi_0 pypi
nvidia-nvjitlink-cu12 12.3.101 pypi_0 pypi
openh264 2.1.1 h4ff587b_0
openjpeg 2.4.0 h3ad879b_0
openssl 3.0.13 h7f8727e_1
opt-einsum 3.3.0 pypi_0 pypi
optax 0.2.2 pypi_0 pypi
optree 0.11.0 pypi_0 pypi
orbax-checkpoint 0.5.10 pypi_0 pypi
packaging 24.0 pypi_0 pypi
pillow 10.3.0 py39h5eee18b_0
pip 23.3.1 py39h06a4308_0
promise 2.3 pypi_0 pypi
protobuf 3.20.3 pypi_0 pypi
psutil 5.9.8 pypi_0 pypi
pygments 2.17.2 pypi_0 pypi
python 3.9.19 h955ad1f_0
pytorch 1.13.1 py3.9_cuda11.6_cudnn8.3.2_0 pytorch
pytorch-cuda 11.6 h867d48c_1 pytorch
pytorch-mutex 1.0 cuda pytorch
pyyaml 6.0.1 pypi_0 pypi
readline 8.2 h5eee18b_0
requests 2.31.0 py39h06a4308_1
rich 13.7.1 pypi_0 pypi
scipy 1.13.0 pypi_0 pypi
setuptools 68.2.2 py39h06a4308_0
six 1.16.0 pypi_0 pypi
sqlite 3.45.3 h5eee18b_0
tbb 2021.8.0 hdb19cb5_0
tensorboard 2.16.2 pypi_0 pypi
tensorboard-data-server 0.7.2 pypi_0 pypi
tensorflow 2.16.1 pypi_0 pypi
tensorflow-datasets 4.9.3 pypi_0 pypi
tensorflow-io-gcs-filesystem 0.37.0 pypi_0 pypi
tensorflow-metadata 1.15.0 pypi_0 pypi
tensorstore 0.1.58 pypi_0 pypi
termcolor 2.4.0 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
toml 0.10.2 pypi_0 pypi
toolz 0.12.1 pypi_0 pypi
torchaudio 0.13.1 py39_cu116 pytorch
torchvision 0.14.1 py39_cu116 pytorch
tqdm 4.66.2 pypi_0 pypi
typing-extensions 4.11.0 pypi_0 pypi
typing_extensions 4.9.0 py39h06a4308_1
tzdata 2024a h04d1e81_0
urllib3 2.2.1 pypi_0 pypi
werkzeug 3.0.2 pypi_0 pypi
wheel 0.41.2 py39h06a4308_0
wrapt 1.16.0 pypi_0 pypi
xz 5.4.6 h5eee18b_1
zipp 3.18.1 pypi_0 pypi
zlib 1.2.13 h5eee18b_1
zstd 1.5.5 hc292b87_1
However, when I run the code, I get the following JAX error:
`jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/zlihm/yfp/RCIG/distill_dataset.py", line 282, in <module>
fire.Fire(main)
File "/home/zlihm/anaconda3/envs/tf/lib/python3.9/site-packages/fire/core.py", line 143, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/home/zlihm/anaconda3/envs/tf/lib/python3.9/site-packages/fire/core.py", line 477, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/home/zlihm/anaconda3/envs/tf/lib/python3.9/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/home/zlihm/yfp/RCIG/distill_dataset.py", line 230, in main
inner_result, _ = algorithms.run_rcig(coreset_images, coreset_labels, model_for_train.init, model_for_train.apply, ds_train, alg_config, key, inner_learning_rate, hvp_learning_rate, lr_tune = True)
File "/home/zlihm/yfp/RCIG/algorithms.py", line 87, in run_rcig
new_train_state, key = get_new_train_state(key, alg_config.pool_learning_rate, alg_config.model_depth, alg_config.has_bn, alg_config.linearize, net_forward_apply, net_forward_init, coreset_images_init.shape, naive_loss = alg_config.naive_loss)
File "/home/zlihm/yfp/RCIG/algorithms.py", line 50, in get_new_train_state
new_params = new_params.unfreeze()
AttributeError: 'dict' object has no attribute 'unfreeze'`
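The failing line calls `.unfreeze()` on the parameters returned by `init`. Older Flax releases returned a `FrozenDict` there, but the environment above uses flax 0.8.3, where `Module.init` returns plain Python dicts, which have no `unfreeze` method. Flax's own `flax.core.unfreeze` accepts both types, so replacing `new_params.unfreeze()` with `flax.core.unfreeze(new_params)` is one option; installing an older Flax release is another. The dependency-free sketch below (the function name is hypothetical) shows the same idea without importing Flax.

```python
from collections.abc import Mapping


def to_mutable_params(params):
    """Return a mutable nested dict whether `params` is a FrozenDict or a plain dict."""
    if hasattr(params, "unfreeze"):
        # FrozenDict-like objects (what older Flax returned from init) unfreeze themselves.
        return params.unfreeze()
    if isinstance(params, Mapping):
        # Plain (possibly nested) dicts, as returned by newer Flax: copy recursively.
        return {key: to_mutable_params(value) for key, value in params.items()}
    # Leaves such as arrays are returned unchanged.
    return params


# Stand-in for the output of model.init(...) under a recent Flax version.
new_params = {"params": {"Dense_0": {"kernel": [[0.0]], "bias": [0.0]}}}
new_params = to_mutable_params(new_params)  # instead of new_params.unfreeze()
print(type(new_params).__name__)  # dict
```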
Will you release the RCIG source code soon?
Hi,
Thanks for your excellent work!
When will you release the source code of RCIG? I want to apply RCIG to other domains, such as time series, so I'm wondering whether you plan to release the code soon or only after your paper is accepted.
Non-conv architectures do not reach the performance reported in the paper
I am using the script from the README but changed eval_arch to resnet and vgg. Neither reaches the accuracy reported in the paper.
python3 eval.py --dataset_name cifar10 --checkpoint_path ./distilled_images_final/0/cifar10_10/checkpoint_10000 --config_path ./configs_final/depth_3.txt --random_seed 0 --eval_arch resnet
ResNet reaches 20% and VGG 40%. Changing --normalization to batch still gives low performance.
In the supplementary material, you state that training without data augmentation gives better results on CIFAR-10; however, your CIFAR-10 configs use data augmentation. Why is this? Furthermore, which hyperparameters am I supposed to use for ResNet on CIFAR-10? I am assuming it is still depth_3.txt.