Comments (3)
I ran into the same problem when using a different LLM. The issue you are seeing is related to equation (14) of the MEMIT paper: in my case, the "aggregate statistic" matrix contained rows/columns that were entirely zero, so it was singular and its inverse could not be computed.
I found two real solutions:

- Easy solution: do not train the layers whose matrices exhibit the problem. If you go to "hparams/MEMIT/EleutherAI_gpt-j-6B.json" you will see that the layers being trained are:

  "layers": [ 3, 4, 5, 6, 7, 8 ],

  Swap out the layers whose matrices cause the problem; if you look at the causal trace, you will see that you have some freedom in choosing between them.
- Hard solution: remove the rows/columns that are full of zeros, compute the inverse of the reduced matrix, and then add the zero rows/columns back. Note that this is not a true inverse, since some columns are zero, but it is an approximation that does not add noise. The problem I experienced with this solution is that, even after removing the zero rows/columns, some "unimportant" coordinates were still inflating the norm of my delta matrix, which makes me cautious.
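To see why the hard solution does not add noise, here is a small standalone sketch (the 3x3 matrix is a made-up toy, not MEMIT's actual statistic): for a symmetric matrix whose only defect is all-zero rows/columns, stripping them, inverting the remaining block, and padding the zeros back reproduces the Moore-Penrose pseudoinverse.

```python
import torch

# Hypothetical 3x3 stand-in for the singular "aggregate statistic":
# row/column 1 is all zeros, so det(C) == 0 and torch.linalg.inv fails.
C = torch.tensor([[4.0, 0.0, 1.0],
                  [0.0, 0.0, 0.0],
                  [1.0, 0.0, 3.0]], dtype=torch.float64)

# Remove row/column 1, invert the nonsingular 2x2 block, pad zeros back.
keep = [0, 2]
block_inv = torch.linalg.inv(C[keep][:, keep])
C_inv = torch.zeros_like(C)
for a, i in enumerate(keep):
    for b, j in enumerate(keep):
        C_inv[i, j] = block_inv[a, b]

# For a symmetric matrix whose only defect is zero rows/columns, this
# matches the Moore-Penrose pseudoinverse.
print(torch.allclose(C_inv, torch.linalg.pinv(C)))  # True
```

In other words, on this class of matrices the trick computes the same thing as `torch.linalg.pinv`, just without running a full SVD.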
To implement this, go to memit_main and add these helper functions at the beginning:
```python
import torch  # already imported at the top of memit_main


def make_null_i(matrix, i):
    """Zero out row i and column i of a copy of `matrix`."""
    new_matrix = matrix.clone()
    new_matrix[:, i] = 0
    new_matrix[i, :] = 0
    return new_matrix


def identify_null_cols(matrix):
    """Return the count and indices of all-zero rows (equal to the all-zero
    columns here, since the matrix is symmetric). Using abs() avoids a row
    of cancelling positive/negative entries being mistaken for zeros."""
    row_sums = matrix.abs().sum(dim=1)
    # squeeze(dim=1) keeps a 1-D result even when only one row is zero,
    # so .tolist() always returns a list.
    zero_rows = torch.nonzero(row_sums == 0).squeeze(dim=1)
    return zero_rows.numel(), zero_rows.tolist()


def remove_column(matrix, i):
    """Drop row i and column i."""
    new_matrix = torch.cat((matrix[:i], matrix[i + 1:]), dim=0)
    return torch.cat((new_matrix[:, :i], new_matrix[:, i + 1:]), dim=1)


def add_zero_column(matrix, i):
    """Re-insert an all-zero row and an all-zero column at index i."""
    new_row = torch.zeros(1, matrix.shape[1], device=matrix.device, dtype=matrix.dtype)
    new_matrix = torch.cat((matrix[:i], new_row, matrix[i:]), dim=0)
    new_col = torch.zeros(matrix.shape[0] + 1, 1, device=matrix.device, dtype=matrix.dtype)
    return torch.cat((new_matrix[:, :i], new_col, new_matrix[:, i:]), dim=1)


def compute_pseudoinverse_matrix(matrix):
    """Invert `matrix` after stripping its all-zero rows/columns, then pad the zeros back."""
    n, ids = identify_null_cols(matrix)
    print(f"There are {n} columns with zeros")
    if n == 0:
        return torch.linalg.inv(matrix)
    # Remove the zero rows/columns that make the matrix singular,
    # from the highest index down so the remaining indices stay valid.
    new_matrix = matrix.clone()
    for id_ in ids[::-1]:
        new_matrix = remove_column(new_matrix, id_)
    # Invert the reduced (nonsingular) matrix.
    new_matrix = torch.linalg.inv(new_matrix)
    # Pad the zero rows/columns back, lowest index first.
    for id_ in ids:
        new_matrix = add_zero_column(new_matrix, id_)
    return new_matrix
```
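One subtlety in compute_pseudoinverse_matrix is the iteration order: indices must be removed from highest to lowest (so the lower indices stay valid) and re-added from lowest to highest. A standalone round-trip check of that bookkeeping, using throwaway slicing helpers rather than the functions above:

```python
import torch

def drop(m, i):
    # Remove row i and column i.
    m = torch.cat((m[:i], m[i + 1:]), dim=0)
    return torch.cat((m[:, :i], m[:, i + 1:]), dim=1)

def pad(m, i):
    # Re-insert an all-zero row, then an all-zero column, at index i.
    row = torch.zeros(1, m.shape[1], dtype=m.dtype)
    m = torch.cat((m[:i], row, m[i:]), dim=0)
    col = torch.zeros(m.shape[0], 1, dtype=m.dtype)
    return torch.cat((m[:, :i], col, m[:, i:]), dim=1)

# Toy 4x4 matrix with zero rows/columns at indices 1 and 3.
m = torch.arange(16, dtype=torch.float64).reshape(4, 4)
m[1, :] = 0
m[:, 1] = 0
m[3, :] = 0
m[:, 3] = 0
ids = [1, 3]

out = m.clone()
for i in ids[::-1]:   # remove from the highest index down
    out = drop(out, i)
for i in ids:         # re-add from the lowest index up
    out = pad(out, i)

print(torch.equal(out, m))  # True
```

If either loop ran in the opposite order, the later indices would shift and the zeros would land in the wrong positions.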
and then change lines 196-199 to:
```python
matrix = (
    hparams.mom2_update_weight * cov.double().detach().cpu()
    + layer_ks.detach().cpu() @ layer_ks.T.detach().cpu()
)
n_null_cols, _ = identify_null_cols(matrix)
if n_null_cols != 0:
    # Singular: fall back to the zero-stripping pseudo-inverse above.
    adj_k = compute_pseudoinverse_matrix(matrix) @ layer_ks.detach().cpu()
else:
    # Nonsingular: the original solve path.
    adj_k = torch.linalg.solve(matrix, layer_ks.detach().cpu())
```
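A lighter-weight alternative to patching in the zero-stripping helpers (a different technique, not what the snippet above does) is torch.linalg.lstsq, which tolerates rank-deficient systems and returns a least-squares solution where torch.linalg.solve would raise:

```python
import torch

# Hypothetical singular 2x2 system: row/column 1 is all zeros.
matrix = torch.tensor([[2.0, 0.0],
                       [0.0, 0.0]], dtype=torch.float64)
b = torch.tensor([[4.0],
                  [0.0]], dtype=torch.float64)

# torch.linalg.solve(matrix, b) would raise here; lstsq returns the
# minimum-norm least-squares solution instead.
x = torch.linalg.lstsq(matrix, b).solution
print(x)  # [[2.0], [0.0]]
```

On CPU the default `gelsy` driver handles rank-deficient inputs; on CUDA the default driver assumes full rank, so this route needs the matrices on CPU (as in the patch above).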
- A possible extra solution: increase the number of edits.
I hope it helps. Good luck!
from memit.
I found your response to be quite valuable. Thank you very much!
Related Issues (17)
- Applying to other models
- Distributing the update across multiple layer
- NotImplementedError for GPT-J-6b
- Missing `data` folder in root directory
- Multi-GPU support for MEMIT
- CUDA out of memory
- what is the difference between multicounterfact and counterfact?
- Discussion: About Knowledge Editing
- IndexError: tuple index out of range at cur_repr processing stage
- Paraphrase prompts' format not compatible with the sample from ROME paper
- muti-counterfact and counterfact
- No optimization after first step
- GPU not big enough? I'm using A5500 24GB RAM
- IndexError: tuple index out of range
- Can it work with Llama 3 / other 7b models?
- How to detemine which hparam layers