
Comments (3)

dtamayo-nlp commented on July 17, 2024

I ran into the same problem with a different LLM. What you are seeing is related to equation (14) of the MEMIT paper. In my case, the "aggregate statistic $C_0$" had rows and columns that were entirely zero, and they stayed zero even after adding $K_1 K_1^T$. A matrix with an all-zero row is singular, so it has no inverse. Looking at how these matrices are built, a zero row means that the corresponding coordinate of the hidden states is never used. So how can you solve it?
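A toy reproduction of the failure, as a sketch only: NumPy stands in for torch, and the sizes, the unused coordinate, and `lam` (standing in for `hparams.mom2_update_weight`) are illustrative assumptions. If one hidden coordinate is never activated, both $C_0$ and $K_1 K_1^T$ have a zero row and column there, so their weighted sum is singular.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4        # toy hidden size (illustrative)
unused = 2   # hidden coordinate that is never activated

# Keys used to estimate C_0; the unused coordinate is always zero
K0 = rng.standard_normal((d, 100))
K0[unused] = 0.0
C0 = K0 @ K0.T  # unnormalised second-moment statistic

# Keys K_1 of the requested edits; they do not touch the unused coordinate either
K1 = rng.standard_normal((d, 5))
K1[unused] = 0.0

lam = 1.0  # stands in for hparams.mom2_update_weight
A = lam * C0 + K1 @ K1.T  # the matrix that has to be inverted / solved against

print(A[unused])         # all zeros; by symmetry, column `unused` is zero too
print(np.linalg.det(A))  # zero: the LU factorisation hits an exact zero pivot

try:
    np.linalg.inv(A)
except np.linalg.LinAlgError as err:
    print("inversion failed:", err)
```

Adding more keys to $K_1$ only restores invertibility if some of those keys are non-zero in the unused coordinate.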

I found two real solutions:

  1. Easy solution: do not edit the layers that have this problem. If you open "hparams/MEMIT/EleutherAI_gpt-j-6B.json", you'll see that the layers being edited are:
    "layers": [ 3, 4, 5, 6, 7, 8 ],
    Swap out the layers whose matrices cause the problem; if you look at the causal trace, you will see that you have some freedom in which layers to choose.
  2. Hard solution: remove the all-zero rows/columns, invert the remaining matrix, and insert the zero rows/columns back. The result is not a true inverse (the original matrix has none), because the re-inserted rows and columns stay zero. Since the matrix is symmetric, it is in fact the Moore-Penrose pseudoinverse, provided the remaining block is invertible, so it adds no noise along the unused coordinates. The problem I ran into with this solution is that, even after removing the zero rows/columns, some "unimportant" coordinates still inflated the norm of my delta matrix, which makes me cautious about it.
    To implement this, open memit_main.py and add these functions near the top:
import torch

def make_null_i(matrix, i):
    # Zero out row i and column i (not used by the functions below)
    new_matrix = matrix.clone()
    new_matrix[:, i] = 0
    new_matrix[i, :] = 0
    return new_matrix

def identify_null_cols(matrix):
    # A row counts as null only if *every* entry is zero; a zero row sum is not enough
    zero_rows = torch.nonzero((matrix == 0).all(dim=1)).flatten()
    # flatten() keeps this 1-D, so tolist() returns a list even when there is one match
    return zero_rows.numel(), zero_rows.tolist()

def remove_column(matrix, i):
    # Remove row i and column i (torch.cat already returns a new tensor)
    new_matrix = torch.cat((matrix[:i], matrix[i+1:]), dim=0)
    new_matrix = torch.cat((new_matrix[:, :i], new_matrix[:, i+1:]), dim=1)
    return new_matrix

def add_zero_column(matrix, i):
    # Insert a zero row and a zero column at index i
    new_row = torch.zeros(1, matrix.shape[1], device=matrix.device, dtype=matrix.dtype)
    new_col = torch.zeros(matrix.shape[0] + 1, 1, device=matrix.device, dtype=matrix.dtype)
    new_matrix = torch.cat((matrix[:i], new_row, matrix[i:]), dim=0)
    new_matrix = torch.cat((new_matrix[:, :i], new_col, new_matrix[:, i:]), dim=1)
    return new_matrix

def compute_pseudoinverse_matrix(matrix):
    n, ids = identify_null_cols(matrix)
    print(f"There are {n} all-zero rows/columns")
    if n == 0:
        return torch.linalg.inv(matrix)
    # Remove the zero rows/columns that make the matrix singular
    # (in descending order, so the remaining indices stay valid)
    new_matrix = matrix.clone()
    for id_ in ids[::-1]:
        new_matrix = remove_column(new_matrix, id_)
    # Invert the remaining block
    new_matrix = torch.linalg.inv(new_matrix)
    # Re-insert the zero rows/columns at their original positions (ascending order)
    for id_ in ids:
        new_matrix = add_zero_column(new_matrix, id_)
    return new_matrix

Then replace lines 196-199 of memit_main.py (the computation of adj_k) with:

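# Build the Eq. (14) matrix on the CPU in double precision
ks = layer_ks.detach().double().cpu()
matrix = hparams.mom2_update_weight * cov.double().detach().cpu() + ks @ ks.T
n_nul_cols, _ = identify_null_cols(matrix)
if n_nul_cols != 0:
    # Singular: use the zero-padded inverse of the non-null block
    adj_k = compute_pseudoinverse_matrix(matrix) @ ks
else:
    adj_k = torch.linalg.solve(matrix, ks)
# Move adj_k back to the device of layer_ks for the rest of the update
adj_k = adj_k.to(layer_ks.device)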
  3. Extra possible solution? Increase the number of edits.
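Two claims behind solution 2 can be checked in a small sketch. NumPy stands in for torch here, and `zero_padded_inverse` is a hypothetical name, not part of MEMIT. First, null rows should be detected by testing that every entry is zero, since a row of a perfectly invertible matrix can sum to zero. Second, for a symmetric matrix, the zero-padded inverse coincides with the Moore-Penrose pseudoinverse whenever the kept block is invertible.

```python
import numpy as np

def zero_padded_inverse(matrix):
    """Invert the non-null block of a symmetric matrix and pad the result with zeros.

    Hypothetical NumPy counterpart of compute_pseudoinverse_matrix: it uses a
    boolean mask instead of removing and re-inserting one index at a time.
    """
    # A row is null only if every entry is zero (checking the row sum is not enough)
    null = np.all(matrix == 0, axis=1)
    keep = np.ix_(~null, ~null)
    out = np.zeros_like(matrix)
    out[keep] = np.linalg.inv(matrix[keep])  # assumes the kept block is invertible
    return out

# Symmetric PSD matrix whose coordinates 1 and 3 are never used
rng = np.random.default_rng(1)
d, unused = 5, [1, 3]
K = rng.standard_normal((d, 50))
K[unused] = 0.0
A = K @ K.T

P = zero_padded_inverse(A)
print(np.allclose(P, np.linalg.pinv(A, rcond=1e-10)))  # compare with NumPy's pseudoinverse
print(np.allclose(A @ P @ A, A))                      # first Moore-Penrose condition

# Why a row-sum test is fragile: this matrix is invertible, yet row 0 sums to zero
M = np.array([[1.0, -1.0], [-1.0, 2.0]])
print(M.sum(axis=1))
print(np.allclose(zero_padded_inverse(M), np.linalg.inv(M)))  # no null rows: plain inverse
```

In torch, the same comparison can be made against `torch.linalg.pinv`. The boolean-mask indexing builds the result in one step rather than concatenating once per removed coordinate, which may matter for hidden sizes in the thousands.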

I hope it helps. Good luck!

mumuyeye commented on July 17, 2024

I found your response to be quite valuable. Thank you very much!

mumuyeye commented on July 17, 2024
