Comments (9)
Actually, that's not true. LayerNorm only and always normalizes over the last axis.
```python
import mlx.nn as nn
import mlx.core as mx

ln = nn.LayerNorm(32)
x = mx.random.uniform(shape=(10, 32))
print(ln(x).sum(axis=-1))  # close to 0
print(ln(x).sum(axis=0))   # not close to 0
```
Since MLX NN standardizes on the feature dimension being last, we don't have plans to include an axis parameter in our LayerNorm. It sounds like the current behavior works for what you want since it is consistent with Keras?
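If you ever do need to normalize over a different axis, one workaround (just a sketch, not a built-in option) is to move that axis to the end, apply LayerNorm, and move it back:

```python
import mlx.core as mx
import mlx.nn as nn

# Sketch of a workaround: normalize over axis 0 by swapping it to the
# last position, applying LayerNorm, and swapping back.
x = mx.random.uniform(shape=(10, 32))
ln = nn.LayerNorm(10)                              # normalized size matches axis 0
y = mx.swapaxes(ln(mx.swapaxes(x, 0, -1)), 0, -1)
print(y.sum(axis=0))                               # now close to 0 along axis 0
```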
If there is something more here, let me know and we can reopen/discuss further.
But if I use the weights and bias from Keras in LayerNorm I don't get the same result. Why?
If I remove LayerNorm from both models, the assert passes.
```python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, LayerNormalization
import numpy as np
import mlx.nn as nn
from mlx.utils import tree_flatten
import mlx.core as mx

# Define the Keras model
model = Sequential([
    Embedding(input_dim=40, output_dim=32, input_length=20, name="Embedding"),
    LayerNormalization(name='LN1'),
])
# Here input_dim is the vocabulary size and output_dim is the embedding size.
model.compile(
    optimizer='adam',   # Optimizer
    loss='mse',         # Mean Squared Error for regression tasks
    metrics=['mae'],    # Mean Absolute Error for regression metrics
)
model.summary()

def exposeweights(model):
    w = {}
    for layer in model.layers:
        weights = layer.get_weights()  # list of all weight tensors in the layer
        print(layer.name)
        for i, weight in enumerate(weights):
            # Dense / linear kernels are stored transposed relative to MLX Linear weights
            if layer.name + "." + str(i) in ['Output.0', 'Proj.0', 'TimeDistributed.0']:
                w[layer.name + "." + str(i)] = weight.T
            else:
                w[layer.name + "." + str(i)] = weight
    return w

w = exposeweights(model)
np.savez('w.npz', **w)

class mlxduplicate(nn.Module):
    def __init__(self):
        super().__init__()
        self.Embedding = nn.Embedding(num_embeddings=40, dims=32)
        self.LN1 = nn.LayerNorm(32)

    def __call__(self, x):
        x = self.Embedding(x)
        x = self.LN1(x)
        return x

model_mlx = mlxduplicate()
we = 0
for k, x in tree_flatten(model_mlx.parameters()):
    we += x.size
    print(x.size, k)
print(we)

tensor_loaded = np.load('w.npz')

def replace_key(key: str) -> str:
    key = key.replace("Embedding.0", "Embedding.weight")
    key = key.replace("LN1.0", "LN1.weight")
    key = key.replace("LN1.1", "LN1.bias")
    return key

# switch layer names of saved Keras tensors
tensors_mlx = {
    replace_key(key): tensor for key, tensor in tensor_loaded.items()
}
for k, v in tensor_loaded.items():
    print(k, v.shape)
for k, v in tensors_mlx.items():
    print(k, v.shape)
np.savez('w_convert_to_mlx.npz', **tensors_mlx)
model_mlx.load_weights('w_convert_to_mlx.npz')

x_train = np.random.randint(0, 39, (2, 20))
keras_output = model.predict(x_train)
keras_output.shape
mlx_output = model_mlx(mx.array(x_train))
assert mlx_output.shape == keras_output.shape
assert np.max(np.abs(mlx_output - mx.array(keras_output))) < 1e-6
```
```
2024-04-28 20:12:54.329876: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Max
2024-04-28 20:12:54.329898: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 128.00 GB
2024-04-28 20:12:54.329900: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 48.00 GB
2024-04-28 20:12:54.329932: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-04-28 20:12:54.329949: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
Embedding (Embedding)        (None, 20, 32)            1280
LN1 (LayerNormalization)     (None, 20, 32)            64
=================================================================
Total params: 1344 (5.25 KB)
Trainable params: 1344 (5.25 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Embedding
LN1
1280 Embedding.weight
32 LN1.bias
32 LN1.weight
1344
Embedding.0 (40, 32)
LN1.0 (32,)
LN1.1 (32,)
Embedding.weight (40, 32)
LN1.weight (32,)
LN1.bias (32,)
2024-04-28 20:12:54.785211: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
1/1 [==============================] - 1s 832ms/step
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[2], line 93
     89 mlx_output = model_mlx(mx.array(x_train))
     90 assert mlx_output.shape == keras_output.shape
---> 93 assert np.max(np.abs(mlx_output-mx.array(keras_output))) < 1e-6

AssertionError:
```
It looks like Keras uses a different default (and much higher) epsilon
for numerical stability. You can set this in the MLX LayerNorm
constructor. The following passes:
```python
from tensorflow.keras.layers import LayerNormalization
import numpy as np
import mlx.nn as nn
import mlx.core as mx

# Keras LayerNormalization with its default epsilon (1e-3)
ln = LayerNormalization(name='LN1')
x = np.random.uniform(size=(10, 32))
out_keras = np.array(ln(x))

ln_mlx = nn.LayerNorm(32, eps=1e-3)  # note setting epsilon here
out_mlx = np.array(ln_mlx(mx.array(x)))

assert np.abs((out_keras - out_mlx)).max() < 1e-6
```
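If you would rather not hard-code the value, you can also read it off the Keras layer object (a small sketch, assuming the layer is in scope; `epsilon` is the attribute Keras stores it under):

```python
# Match MLX to whatever epsilon the Keras layer was configured with
ln_mlx = nn.LayerNorm(32, eps=ln.epsilon)
```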
Thanks @awni, I very much appreciate your help on that. I still have one question: can you explain this error? For very large datasets I see a strange behaviour.
I use this code:
```python
import math
from typing import Any

import mlx.nn as nn
from mlx.utils import tree_flatten
import numpy as np
import mlx.core as mx
from mlx.nn.layers.base import Module


class AttentionM_(Module):
    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()
        self.output_dims = output_dims
        scale = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, 1),
        )
        if bias:
            self.bias = mx.zeros(shape=(output_dims, 1))

    def _extra_repr(self) -> str:
        return (
            f"input_dims={self.weight.shape[0]}, "
            f"output_dims={self.output_dims}, bias={'bias' in self}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        if "bias" in self:
            x_ = mx.addmm(self["bias"], x, self["weight"])
        else:
            x_ = x @ self["weight"]
        x_ = mx.tanh(x_)
        x_ = mx.expand_dims(mx.softmax(mx.squeeze(x_, axis=-1), axis=-1), axis=-1)
        x = mx.sum(x * x_, axis=1)
        return x


class TimeDistributed_(nn.Module):
    def __init__(self, func: nn.Module):
        super().__init__()
        self.func = func

    def __call__(self, x):
        b_, t_ = x.shape[:2]
        c_ = self.func(x.flatten(0, 1))
        return c_.reshape(b_, t_, *c_.shape[1:])


class Bidirectionnal_(nn.Module):
    def __init__(self, func1: nn.Module, func2: nn.Module):
        super().__init__()
        self.func1 = func1
        self.func2 = func2

    def __call__(self, x):
        h_f, h_b = self.func1(x), self.func2(x[:, ::-1, :])
        return mx.stack([h_f[0], h_b[0]], axis=-1).flatten(-2, -1)


class SmilesX(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        inputdim: int,
        embdim: int,
        lstmdim: int,
        densedim1: int,
        densedim2: int,
        checkpoint: bool,
        debug: bool,
    ):
        super().__init__()
        self.Embedding = nn.Embedding(num_embeddings=vocab_size, dims=embdim)
        self.Image = Bidirectionnal_(
            nn.LSTM(embdim, lstmdim, bias=True),
            nn.LSTM(embdim, lstmdim, bias=True),
        )
        self.TimeDistributed = TimeDistributed_(nn.Linear(2 * lstmdim, densedim1))
        self.AttentionM = AttentionM_(densedim1, inputdim, bias=True)
        self.Layernorm1 = nn.LayerNorm(densedim1, eps=0.001)
        self.Proj = nn.Linear(densedim1, densedim2)
        self.Layernorm2 = nn.LayerNorm(densedim2, eps=0.001)
        self.Output = nn.Linear(densedim2, 1)
        self.lk = nn.LeakyReLU(0.1)
        self.debug = debug

    def __call__(self, x):
        if self.debug:
            print('Input:', x.shape)
        # embedding
        x = self.Embedding(x)
        if self.debug:
            print('Embedding:', x.shape)
        # Bidirectional
        x = self.Image(x)
        if self.debug:
            print('BiLSTM:', x.shape)
        # TimeDistributed
        x = self.TimeDistributed(x)
        if self.debug:
            print('TimeDistributed:', x.shape)
        # self attention
        x = self.AttentionM(x)
        if self.debug:
            print('AttentionM:', x.shape)
        # Layer norm
        x = self.Layernorm1(x)
        if self.debug:
            print('LayerNorm 1:', x.shape)
        x = self.Proj(x)
        if self.debug:
            print('proj:', x.shape)
        x = self.lk(x)
        # Layer norm
        x = self.Layernorm2(x)
        if self.debug:
            print('LayerNorm 2:', x.shape)
        x = self.Output(x)
        if self.debug:
            print('Output:', x.shape)
        return x


model = SmilesX(
    vocab_size=42,
    inputdim=128,
    embdim=32,
    lstmdim=32,
    densedim1=64,
    densedim2=64,
    checkpoint=False,
    debug=False,
)

# Initialize model:
nparams = sum(x.size for k, x in tree_flatten(model.parameters()))
print(f"Training a SMILES-X Model with {nparams} parameters")
xt = 0
for k, x in tree_flatten(model.parameters()):
    print(x.size, k)
    xt += x.size
print(xt)

# test for a big dataset:
X = mx.random.randint(0, 42, [600000, 128])

# evaluate the results by data size
Y1 = model(X[:10, :])
Y3 = model(X[:100, :])
Y2 = model(X[:200, :])
Y4 = model(X[:1000, :])
Y5 = model(X[:10000, :])
Y6 = model(X[:100000, :])
Y7 = model(X[:600000, :])

# validate the results
assert np.max(np.abs(Y1 - Y3[:10])) < 1e-8
assert np.max(np.abs(Y1 - Y2[:10])) < 1e-8
assert np.max(np.abs(Y1 - Y4[:10])) < 1e-8
assert np.max(np.abs(Y1 - Y5[:10])) < 1e-8
assert np.max(np.abs(Y1 - Y6[:10])) < 1e-8
assert np.max(np.abs(Y1 - Y7[:10])) < 1e-8
```
The size is really large, I think some matrices are well over 4B entries. My guess is it's overflowing an integer index somewhere, but I'm not sure where. I'll look into where that is to see if we can add an error message or fix it. For now I would stick to smaller sizes.
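For a rough sense of the scale (a back-of-the-envelope estimate assuming the whole batch goes through the LSTM's input projection in one matmul, using the dimensions from the model above):

```python
# Rough count of the entries produced by the LSTM's input projection when
# the full batch is processed at once (an assumption about where the
# overflow happens, not a confirmed location).
batch, seq, lstmdim = 600000, 128, 32
entries = batch * seq * 4 * lstmdim        # (batch, seq, 4 * lstmdim) output
print(entries)                             # 9,830,400,000 entries
print(entries > 2**31 - 1)                 # True: too big for a 32-bit index
```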
I filed a separate issue about this #1051
Does it happen during evaluation too? It would be nice to be able to set a batch size for inference.
The problem is the large matmul in the LSTM. So if the batch size is larger than about 131k (for the LSTM / model dimensions you provided) then it will break regardless of inference / training modes.
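In the meantime, a simple way to keep inference working on the full dataset is to chunk it yourself (a sketch, assuming `model` and `X` are defined as above; the batch size just needs to stay well under that ~131k limit):

```python
import mlx.core as mx

def predict_in_batches(model, X, batch_size=10000):
    # Run the model chunk by chunk and evaluate each chunk so the big
    # matmul never sees more than `batch_size` rows at once.
    outs = []
    for i in range(0, X.shape[0], batch_size):
        y = model(X[i:i + batch_size])
        mx.eval(y)
        outs.append(y)
    return mx.concatenate(outs, axis=0)

Y = predict_in_batches(model, X, batch_size=10000)
```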