
cvqluu / nn-similarity-diarization

41.0 2.0 12.0 355 KB

Neural-network-based similarity scoring for diarization (PyTorch implementation of "LSTM based Similarity Measurement with Spectral Clustering for Speaker Diarization")

License: MIT License

Python 94.15% Shell 5.85%
pytorch diarization neural-network speech similarity-score similarity kaldi lstm speaker-recognition speaker-diarization

nn-similarity-diarization's People

Contributors

cvqluu

nn-similarity-diarization's Issues

Why does the LSTM model not have a sigmoid function?

import torch.nn as nn


class LSTMSimilarity(nn.Module):

    def __init__(self, input_size=256, hidden_size=256, num_layers=2):
        super(LSTMSimilarity, self).__init__()
        self.lstm = nn.LSTM(input_size,
                            hidden_size,
                            num_layers=num_layers,
                            bidirectional=True,
                            batch_first=True)
        self.fc1 = nn.Linear(hidden_size*2, 64)  # 2x because the LSTM is bidirectional
        self.nl = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        # Compact the LSTM weights into contiguous memory (avoids a cuDNN warning).
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)               # (batch, seq_len, 2 * hidden_size)
        x = self.fc1(x)
        x = self.nl(x)
        x = self.fc2(x).squeeze(2)        # (batch, seq_len): raw scores, no sigmoid
        return x

This is your LSTM model, but in the original paper a sigmoid function is applied before the output, like this:

x = torch.sigmoid(self.fc2(x).squeeze(2))
return x
But when I make this change, the total error rises from 9.7% to 37%, and I don't know why.
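One plausible explanation, assuming the training script pairs this model with a logits-based loss such as PyTorch's `nn.BCEWithLogitsLoss` (the training code is not shown here, so this is an assumption): that loss applies a sigmoid internally, so adding `torch.sigmoid` in `forward()` applies it twice. The pure-Python sketch below mimics one element of such a loss; `sigmoid` and `bce_with_logits` are illustrative helpers, not repository functions.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, the same mapping torch.sigmoid applies element-wise."""
    return 1.0 / (1.0 + math.exp(-x))


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit for one element
    (the sigmoid is applied inside the loss, as in nn.BCEWithLogitsLoss)."""
    p = sigmoid(logit)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


# Case 1: forward() returns raw logits (as in the code above).
# A confident "different speaker" logit yields a small loss for target 0.
raw_logit = -6.0
loss_raw = bce_with_logits(raw_logit, 0.0)

# Case 2: forward() applies sigmoid and the loss applies it again.
# The effective probability sigmoid(sigmoid(x)) always lies between 0.5 and
# about 0.731, so the model can never confidently predict "different speaker".
double = sigmoid(sigmoid(raw_logit))
loss_double = bce_with_logits(sigmoid(raw_logit), 0.0)

print(round(loss_raw, 4), round(double, 4), round(loss_double, 4))
```

Because the doubly squashed probability cannot fall below 0.5, the loss for "different speaker" pairs stays above ln 2, which could plausibly flatten the scores that feed spectral clustering. If the loss is indeed logits-based, a common pattern is to keep `forward()` returning raw scores for training and apply `torch.sigmoid` only when building the affinity matrix at inference; if the script uses `nn.BCELoss` instead, the explicit sigmoid would be required.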

The paper uses the label of the central 750 ms as the label of the whole 1.5 s segment, but your code seems to use the whole 1.5 s segment to compute the label.

import os

import numpy as np
from tqdm import tqdm

# load_n_col, change_base_paths and assign_overlaps are assumed to be helper
# functions defined elsewhere in the repository (not shown here).


def segment_labels(segments, rttm, xvectorscp, xvecbase_path=None):
    segment_cols = load_n_col(segments, numpy=True)
    segment_rows = np.array(list(zip(*segment_cols)))  # transpose columns into per-segment rows
    rttm_cols = load_n_col(rttm, numpy=True)
    vec_utts, vec_paths = load_n_col(xvectorscp, numpy=True)
    if not xvecbase_path:
        xvecbase_path = os.path.dirname(xvectorscp)
    assert sum(vec_utts == segment_cols[0]) == len(segment_cols[0])
    vec_paths = change_base_paths(vec_paths, new_base_path=xvecbase_path)

    # start time + duration = end time
    rttm_cols.append(rttm_cols[3].astype(float) + rttm_cols[4].astype(float))
    recording_ids = sorted(set(segment_cols[1]))  # recording IDs, e.g. iaaa
    events0 = np.array(segment_cols[2:4]).astype(float).transpose()  # segment start/end times
    # reference (ground-truth) start/end times
    events1 = np.vstack([rttm_cols[3].astype(float), rttm_cols[-1]]).transpose()

    rec_batches = []

    for rec_id in tqdm(recording_ids):  # tqdm displays a progress bar
        seg_indexes = segment_cols[1] == rec_id  # boolean mask: == binds tighter than =
        rttm_indexes = rttm_cols[1] == rec_id
        ev0 = events0[seg_indexes]   # start/end times of each segment in rec_id
        ev1 = events1[rttm_indexes]  # start/end times of each reference turn in the RTTM for rec_id
        ev1_labels = rttm_cols[7][rttm_indexes]  # speaker of each reference turn
        # Labels are assigned from the full segment boundaries (ev0).
        ev0_labels = assign_overlaps(ev0, ev1, ev1_labels)
        ev0_labels = ['{}_{}'.format(rec_id, l) for l in ev0_labels]  # speaker IDs, e.g. iaaa_A
        batch = (segment_cols[0][seg_indexes], ev0_labels, vec_paths[seg_indexes], segment_rows[seg_indexes])
        rec_batches.append(batch)

    return recording_ids, rec_batches
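A minimal sketch of the centre-window rule described in this issue, assuming each 1.5 s segment should take the label of the reference speaker with the largest overlap inside its central 0.75 s. `center_windows` and `label_by_max_overlap` are hypothetical helpers written for illustration; the repository's own `assign_overlaps` is not shown here and may resolve overlaps differently. Inside `segment_labels`, one would pass `center_windows(ev0)` as the segment boundaries when calling `assign_overlaps`.

```python
import numpy as np


def center_windows(events: np.ndarray, width: float = 0.75) -> np.ndarray:
    """Shrink each (start, end) row to a window of `width` seconds centred on its midpoint."""
    mids = events.mean(axis=1)
    half = width / 2.0
    return np.stack([mids - half, mids + half], axis=1)


def label_by_max_overlap(segs, ref, ref_labels):
    """Hypothetical helper: label each segment with the reference turn it overlaps most.
    If a segment overlaps no turn at all, argmax falls back to the first turn's label;
    a real implementation would need to handle that case explicitly."""
    labels = []
    for s, e in segs:
        # Overlap of [s, e] with every reference turn, clipped at 0.
        overlap = np.clip(np.minimum(e, ref[:, 1]) - np.maximum(s, ref[:, 0]), 0.0, None)
        labels.append(ref_labels[int(np.argmax(overlap))])
    return labels


# Toy example (made-up times): three reference turns and one 1.5 s segment.
ref_turns = np.array([[0.0, 0.6], [0.6, 1.1], [1.1, 1.5]])
ref_speakers = ["A", "B", "C"]
segment = np.array([[0.0, 1.5]])
print(label_by_max_overlap(segment, ref_turns, ref_speakers))                  # ['A']
print(label_by_max_overlap(center_windows(segment), ref_turns, ref_speakers))  # ['B']
```

With the whole segment, speaker A wins with 0.6 s of overlap; restricted to the centre window [0.375, 1.125], speaker B covers 0.5 s and wins instead, which is exactly the kind of disagreement the issue points out.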

How can datasets be obtained for free?

Hello,

I'm glad you published the code! But not having a dataset is a real headache. How can I obtain the corresponding datasets? Datasets for speaker segmentation and clustering (speaker diarization) are hard to come by, and without one many experiments cannot be done.

I would be very grateful for any suggestions!
