Hello, thanks for sharing your code. I have a question about it, based on this usage example:
```python
import torch
from sparse_utils import Pruner

train_loader = ...
epochs = ...
model = ...

# create a Pruner class instance
pruner = Pruner(model, device=..., final_rate=..., nbatches=..., epochs=...)

optimizer = ...
loss_fn = ...

for epoch in range(epochs):
    for data, target in train_loader:
        # update the pruning threshold based on the iteration number and the scheduler used
        pruner.update_thresh()
        output = model(data)
        loss = loss_fn(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # update the pruning threshold after the last step of the optimizer
        pruner.update_thresh(end_of_batch=True)

# finalize sparse model
pruner.desparsify()
```
To save the sparsified model, do I just need to save the model as usual after calling `pruner.desparsify()`?
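For reference, this is what I would try: the standard PyTorch `state_dict` pattern. It assumes (I haven't confirmed this) that `desparsify()` leaves the pruned weights as plain zeros in the model's ordinary parameters, with no extra masks or buffers that also need saving. The small `nn.Linear` layer with a random mask is only a hypothetical stand-in for the real pruned model, and the `sparsity` helper is my own, not part of `sparse_utils`:

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the pruned model: zero out roughly half of the
# weights, mimicking what pruning might leave behind (assumption, not sparse_utils).
model = nn.Linear(8, 4)
with torch.no_grad():
    mask = torch.rand_like(model.weight) > 0.5
    model.weight.mul_(mask.to(model.weight.dtype))


def sparsity(module: nn.Module) -> float:
    """Fraction of exactly-zero entries across all parameters (hypothetical helper)."""
    total = sum(p.numel() for p in module.parameters())
    zeros = sum(int((p == 0).sum().item()) for p in module.parameters())
    return zeros / total


# Save the state_dict to an in-memory buffer; a file path works the same way.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Reload into a fresh model with the same architecture.
buffer.seek(0)
restored = nn.Linear(8, 4)
restored.load_state_dict(torch.load(buffer))

# The zeros should survive the round trip unchanged.
print(f"sparsity before save: {sparsity(model):.2f}")
print(f"sparsity after load:  {sparsity(restored):.2f}")
```

If this is right, the zeros are still stored in dense tensors, so the checkpoint would be the same size on disk as the unpruned model.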