Comments (1)
- from Kosmos
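```python
# Methods excerpted from the RetNet decoder class; attributes such as
# self.embed_tokens and self.embed_scale are set in its __init__ (not shown).

def forward_embedding(
    self,
    tokens,
    token_embedding=None,
    incremental_state=None,
):
    # After the first incremental decoding step, only the newest token is embedded.
    if incremental_state is not None and not self.is_first_step(incremental_state):
        tokens = tokens[:, -1:]
    if token_embedding is None:
        token_embedding = self.embed_tokens(tokens)
    # Scaled embeddings, then optional LayerNorm and dropout.
    x = embed = self.embed_scale * token_embedding
    if self.layernorm_embedding is not None:
        x = self.layernorm_embedding(x)
    x = self.dropout_module(x)
    return x, embed


def is_first_step(self, incremental_state):
    # No cache means no incremental decoding, hence never a "first step".
    if incremental_state is None:
        return False
    return incremental_state.get("is_first_step", False)
```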
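A minimal pure-Python sketch of the token selection in `forward_embedding`, with nested lists standing in for the `(batch, seq_len)` token tensor. The function name `tokens_to_embed` and the sample batch are hypothetical, and the sketch assumes the generation loop stores an `is_first_step` flag in the state dict, as the `.get` call in `is_first_step` implies.

```python
from typing import List, Optional


def is_first_step(incremental_state: Optional[dict]) -> bool:
    # Mirrors the quoted helper: no state dict means "not a first step".
    if incremental_state is None:
        return False
    return incremental_state.get("is_first_step", False)


def tokens_to_embed(tokens: List[List[int]], incremental_state: Optional[dict]) -> List[List[int]]:
    # Hypothetical helper mirroring the slicing in forward_embedding:
    # past the first incremental step, keep only the newest token per sequence.
    if incremental_state is not None and not is_first_step(incremental_state):
        return [seq[-1:] for seq in tokens]
    return tokens


batch = [[5, 8, 13], [2, 4, 6]]
print(tokens_to_embed(batch, None))                     # [[5, 8, 13], [2, 4, 6]]  (no cache)
print(tokens_to_embed(batch, {"is_first_step": True}))  # [[5, 8, 13], [2, 4, 6]]  (first step)
print(tokens_to_embed(batch, {}))                       # [[13], [6]]  (later steps)
```

So the whole prompt is embedded once, and each later call embeds a single token while the per-layer entries in `incremental_state` carry the history.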
```python
import torch.nn.functional as F

# Decoder method (excerpt); self.* attributes come from the class __init__, not shown.


def forward(
    self,
    prev_output_tokens,
    incremental_state=None,
    features_only=False,
    return_all_hiddens=False,
    token_embeddings=None,
    **kwargs
):
    # embed tokens
    x, _ = self.forward_embedding(
        prev_output_tokens, token_embeddings, incremental_state
    )
    is_first_step = self.is_first_step(incremental_state)

    # Chunkwise retention needs a length divisible by recurrent_chunk_size,
    # so right-pad the sequence dimension and remember the padded length.
    if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
        padding_len = self.recurrent_chunk_size - prev_output_tokens.size(1) % self.recurrent_chunk_size
        slen = prev_output_tokens.size(1) + padding_len
        x = F.pad(x, (0, 0, 0, padding_len))
    else:
        slen = prev_output_tokens.size(1)

    # relative position; the second argument is True only after the first incremental step
    retention_rel_pos = self.retnet_rel_pos(
        slen,
        incremental_state is not None and not is_first_step,
        chunkwise_recurrent=self.chunkwise_recurrent,
    )

    # decoder layers
    inner_states = [x]
    l_aux = []
    for idx, layer in enumerate(self.layers):
        # Give each layer its own cache slot when decoding incrementally.
        if incremental_state is None or is_first_step:
            if is_first_step and incremental_state is not None:
                if idx not in incremental_state:
                    incremental_state[idx] = {}
        else:
            if idx not in incremental_state:
                incremental_state[idx] = {}
        x, l_aux_i = layer(
            x,
            incremental_state[idx] if incremental_state is not None else None,
            retention_rel_pos=retention_rel_pos,
            chunkwise_recurrent=self.chunkwise_recurrent,
        )
        l_aux.append(l_aux_i)
        inner_states.append(x)

    # Drop the chunk padding added above.
    if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
        x = x[:, :prev_output_tokens.size(1), :]

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    if not features_only:
        x = self.output_layer(x)

    return x, {
        "inner_states": inner_states,
        "l_aux": l_aux,
        "attn": None,
    }
```
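The chunkwise branch of `forward` right-pads the sequence so its length is a multiple of `recurrent_chunk_size`, then trims the padding after the decoder layers. A stdlib-only sketch of that arithmetic follows; a list of per-timestep hidden vectors stands in for one sequence of `x`, and `chunk_padding` / `pad_then_trim` are hypothetical names.

```python
from typing import List, Tuple


def chunk_padding(seq_len: int, chunk_size: int, chunkwise: bool = True) -> Tuple[int, int]:
    # Same condition and formula as forward: pad only when chunkwise and not divisible.
    if chunkwise and seq_len % chunk_size != 0:
        padding_len = chunk_size - seq_len % chunk_size
    else:
        padding_len = 0
    return padding_len, seq_len + padding_len


def pad_then_trim(rows: List[List[float]], chunk_size: int) -> Tuple[int, List[List[float]]]:
    # rows: one sequence, one hidden vector per timestep.
    padding_len, slen = chunk_padding(len(rows), chunk_size)
    hidden = len(rows[0])
    # Append zero rows at the end of the sequence dimension.
    padded = rows + [[0.0] * hidden for _ in range(padding_len)]
    # ... the decoder layers would run on `padded` (length slen) here ...
    return slen, padded[: len(rows)]


print(chunk_padding(10, 4))  # (2, 12)
print(chunk_padding(8, 4))   # (0, 8)
print(pad_then_trim([[1.0, 2.0]] * 5, 4)[0])  # 8
```

In PyTorch, `F.pad(x, (0, 0, 0, padding_len))` reads the pad tuple from the last dimension backwards: nothing on the hidden dimension, `padding_len` zeros on the right of the sequence dimension, which is what appending zero rows models here.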
```python
def output_layer(self, features):
    # Final projection of the hidden states (typically to vocabulary logits).
    return self.output_projection(features)
```
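Inside `forward`'s layer loop, the nested branches that create per-layer cache slots collapse to one rule, because `is_first_step` can only be True when `incremental_state` is not None. A small stdlib check of that equivalence; both `init_layer_state_*` names and `check_equivalence` are hypothetical, and `is_first_step` is re-implemented as a plain function.

```python
import copy
from typing import Optional


def is_first_step(state: Optional[dict]) -> bool:
    if state is None:
        return False
    return state.get("is_first_step", False)


def init_layer_state_original(state: Optional[dict], idx: int, first: bool) -> None:
    # Branching as written in forward.
    if state is None or first:
        if first and state is not None:
            if idx not in state:
                state[idx] = {}
    else:
        if idx not in state:
            state[idx] = {}


def init_layer_state_simplified(state: Optional[dict], idx: int) -> None:
    # Equivalent rule: whenever a cache dict exists, ensure a slot for layer idx.
    if state is not None:
        state.setdefault(idx, {})


def check_equivalence(num_layers: int = 3) -> bool:
    templates = [None, {}, {"is_first_step": True}, {"is_first_step": False, 0: {"kv": 1}}]
    for template in templates:
        a, b = copy.deepcopy(template), copy.deepcopy(template)
        for idx in range(num_layers):
            init_layer_state_original(a, idx, is_first_step(a))
            init_layer_state_simplified(b, idx)
        if a != b:
            return False
    return True


print(check_equivalence())  # True
```

The check compares only the dictionary contents after setup, not anything the layers later store in those slots.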
from palm-e.
Related Issues (12)
- Unpack output
- Training SOP Link not found
- GPU size for training?
- how to transform the output to the planing?
- Fixing "from palme.model import PALME" in train.py
- Problem with example.py
- About training
- How to extract input embeddings for text
- How do you know if it's right without training?
- Can PaLM-E be used on any robot?
- more examples