panda0406 / zero-shot-knowledge-graph-relational-learning
Generative Adversarial Zero-Shot Relational Learning for Knowledge Graphs
Why is there a need to calculate the real loss? The optimizer only updates the generator parameters, while the real loss is computed from the extractor and discriminator. It seems that this loss does not contribute any gradients.
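To illustrate the question above, here is a minimal PyTorch sketch (hypothetical stand-in modules, not the repo's actual training loop) showing that a loss whose computation graph does not include the generator leaves the generator's gradients untouched, while a loss that passes through the generator populates them:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(4, 4)    # stand-in for the generator
disc = nn.Linear(4, 1)   # stand-in for the discriminator (extractor omitted)
x = torch.randn(2, 4)

real_loss = disc(x).mean()        # the generator is NOT in this graph
real_loss.backward()
print(all(p.grad is None for p in gen.parameters()))      # True

fake_loss = disc(gen(x)).mean()   # the generator IS in this graph
fake_loss.backward()
print(all(p.grad is not None for p in gen.parameters()))  # True
```

So if `real_loss` is only ever computed from the extractor and discriminator, calling `backward()` on it cannot move the generator; it may still matter for the discriminator's update or for logging.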
Hey, I encountered some problems while reproducing your paper's experiments. I get an "Out of Memory" warning when I run the loop `for relname in self.train_tasks.keys():` in trainer.py on the NELL dataset. I would like to know which GPUs you used in your experiments. Please, this is important to me.
Again, this line is not compatible with the Wiki dataset.
Hello, I was wondering how you found the relation descriptions for the Wiki dataset, which I can't find on the wiki website.
Hi, Pengda, thanks a lot for sharing the wonderful code. Hope to see the corresponding code for Wiki soon.
Hi, Pengda. I am very interested in your work; it's really great work!
I tried to run your code recently, but I found it doesn't work well on the Wiki-ZS dataset. I guess it may be because of the choice of extractor pretraining steps. You provided the pretraining steps for NELL (16000), but 16000 may not be the best choice for Wiki. Could you provide your pretraining steps for Wiki?
Sincerely hope for your reply.
Hi, Pengda. Thanks for sharing the code.
After downloading the given dataset, I found that the ComplEx embedding matrix in the Embed_used directory is missing its real or imaginary part. More details are shown below.
ComplEx.npz only has a matrix with the shape (69127, 100).
Both TransE.npz and DistMult.npz have embedding matrices with the shape (65567, 100).
According to the implementation details, ComplEx should be (69127, 200).
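For context on why (69127, 200) is expected: ComplEx represents each symbol with both a real and an imaginary vector, so a 100-dimensional ComplEx model is typically stored as the two parts concatenated, while TransE and DistMult use a single real-valued vector. A small NumPy sketch (random placeholder data, not the actual released embeddings) of the expected layouts:

```python
import numpy as np

num_complex_symbols, dim = 69127, 100

# ComplEx: each symbol has a real part and an imaginary part
real = np.random.randn(num_complex_symbols, dim).astype(np.float32)
imag = np.random.randn(num_complex_symbols, dim).astype(np.float32)
complex_emb = np.concatenate([real, imag], axis=1)
print(complex_emb.shape)   # (69127, 200)

# TransE / DistMult: a single real-valued vector per symbol
transe_emb = np.random.randn(65567, dim).astype(np.float32)
print(transe_emb.shape)    # (65567, 100)
```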
I'm unable to download your datasets from Baidu. Can you make the links accessible to people outside China?
Hi, Pengda. I am very interested in your work; it's really great work!
I tried to run your code recently, but I couldn't find the relation embedding being used in the neighbor_encoder function; is it used somewhere else? Here is the code for the feature extractor in Network.py:
Sincerely hope for your reply.
class Extractor(nn.Module):
    """
    Matching metric based on KB Embeddings
    """
    def __init__(self, embed_dim, num_symbols, embed=None):
        super(Extractor, self).__init__()
        self.embed_dim = int(embed_dim)
        self.pad_idx = num_symbols
        self.symbol_emb = nn.Embedding(num_symbols + 1, embed_dim, padding_idx=num_symbols)
        self.num_symbols = num_symbols
        self.gcn_w = nn.Linear(self.embed_dim, int(self.embed_dim / 2))
        self.gcn_b = nn.Parameter(torch.FloatTensor(self.embed_dim))
        self.fc1 = nn.Linear(self.embed_dim, int(self.embed_dim / 2))
        self.fc2 = nn.Linear(self.embed_dim, int(self.embed_dim / 2))
        self.dropout = nn.Dropout(0.2)
        self.dropout_e = nn.Dropout(0.2)
        self.symbol_emb.weight.data.copy_(torch.from_numpy(embed))
        self.symbol_emb.weight.requires_grad = False
        d_model = self.embed_dim * 2
        self.support_encoder = SupportEncoder(d_model, 2 * d_model, dropout=0.2)
        # self.query_encoder = QueryEncoder(d_model, process_steps)

    def neighbor_encoder(self, connections, num_neighbors):
        '''
        connections: (batch, 200, 2)
        num_neighbors: (batch,)
        '''
        num_neighbors = num_neighbors.unsqueeze(1)
        entities = connections[:, :, 1].squeeze(-1)
        ent_embeds = self.dropout(self.symbol_emb(entities))  # (batch, 50, embed_dim)
        concat_embeds = ent_embeds
        out = self.gcn_w(concat_embeds)
        out = torch.sum(out, dim=1)  # (batch, embed_dim)
        out = out / num_neighbors
        return out.tanh()

    def entity_encoder(self, entity1, entity2):
        entity1 = self.dropout_e(entity1)
        entity2 = self.dropout_e(entity2)
        entity1 = self.fc1(entity1)
        entity2 = self.fc2(entity2)
        entity = torch.cat((entity1, entity2), dim=-1)
        return entity.tanh()  # (batch, embed_dim)

    def forward(self, query, support, query_meta=None, support_meta=None):
        '''
        query: (batch_size, 2)
        support: (few, 2)
        return: (batch_size, )
        '''
        query_left_connections, query_left_degrees, query_right_connections, query_right_degrees = query_meta
        support_left_connections, support_left_degrees, support_right_connections, support_right_degrees = support_meta

        query_e1 = self.symbol_emb(query[:, 0])  # (batch, embed_dim)
        query_e2 = self.symbol_emb(query[:, 1])  # (batch, embed_dim)
        query_e = self.entity_encoder(query_e1, query_e2)

        support_e1 = self.symbol_emb(support[:, 0])  # (batch, embed_dim)
        support_e2 = self.symbol_emb(support[:, 1])  # (batch, embed_dim)
        support_e = self.entity_encoder(support_e1, support_e2)

        query_left = self.neighbor_encoder(query_left_connections, query_left_degrees)
        query_right = self.neighbor_encoder(query_right_connections, query_right_degrees)
        support_left = self.neighbor_encoder(support_left_connections, support_left_degrees)
        support_right = self.neighbor_encoder(support_right_connections, support_right_degrees)

        query_neighbor = torch.cat((query_left, query_e, query_right), dim=-1)  # tanh
        support_neighbor = torch.cat((support_left, support_e, support_right), dim=-1)  # tanh

        support = support_neighbor
        query = query_neighbor

        support_g = self.support_encoder(support)  # 1 * 100
        query_g = self.support_encoder(query)
        support_g = torch.mean(support_g, dim=0, keepdim=True)

        # cosine similarity
        matching_scores = torch.matmul(query_g, support_g.t()).squeeze()
        return query_g, matching_scores
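Regarding the question above: in this snippet, `neighbor_encoder` only embeds the entity side of each connection (`connections[:, :, 1]`), so the relation ids (presumably `connections[:, :, 0]`) are indeed unused here. A common variant, sketched below as a self-contained module (a hypothetical design for illustration, not necessarily the authors' intended one), concatenates the relation and entity embeddings of each neighbor before the linear projection:

```python
import torch
import torch.nn as nn

class NeighborEncoder(nn.Module):
    """Sketch: encode (relation, entity) neighbor pairs instead of entities alone."""
    def __init__(self, embed_dim, num_symbols):
        super().__init__()
        self.symbol_emb = nn.Embedding(num_symbols + 1, embed_dim,
                                       padding_idx=num_symbols)
        # input is [rel_emb ; ent_emb], hence 2 * embed_dim
        self.gcn_w = nn.Linear(2 * embed_dim, embed_dim // 2)

    def forward(self, connections, num_neighbors):
        rels = connections[:, :, 0]               # (batch, max_neighbors)
        ents = connections[:, :, 1]
        rel_embeds = self.symbol_emb(rels)        # (batch, max_neighbors, embed_dim)
        ent_embeds = self.symbol_emb(ents)
        concat = torch.cat((rel_embeds, ent_embeds), dim=-1)
        out = self.gcn_w(concat).sum(dim=1)       # (batch, embed_dim // 2)
        return (out / num_neighbors.unsqueeze(1)).tanh()

enc = NeighborEncoder(embed_dim=100, num_symbols=10)
conns = torch.randint(0, 10, (2, 5, 2))           # (batch=2, neighbors=5, [rel, ent])
degrees = torch.tensor([5.0, 5.0])
print(enc(conns, degrees).shape)                  # torch.Size([2, 50])
```

Note the projection must take `2 * embed_dim` inputs once the relation embedding is concatenated, unlike the `gcn_w` in the code above.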
Hi,
How did you find the NELL relation descriptions? Can you provide the source link?
Best Regards~
Chen