
Comments (9)

awni commented on July 28, 2024

Actually, that's not true. LayerNorm only and always normalizes over the last axis.

import mlx.nn as nn
import mlx.core as mx

ln = nn.LayerNorm(32)
x = mx.random.uniform(shape=(10, 32))
print(ln(x).sum(axis=-1)) # close to 0
print(ln(x).sum(axis=0)) # not close to 0

Since MLX NN standardizes on the feature dimension being last, we don't have plans to include an axis parameter in our LayerNorm. It sounds like the current behavior works for what you want since it is consistent with Keras?
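
If you do need to normalize over a non-final axis, one workaround (just a sketch, not an MLX API) is to move that axis to the end with mx.swapaxes, apply the LayerNorm, and move it back:

import mlx.core as mx
import mlx.nn as nn

ln = nn.LayerNorm(10)
x = mx.random.uniform(shape=(10, 32))  # suppose axis 0 is the feature axis
y = mx.swapaxes(ln(mx.swapaxes(x, 0, -1)), 0, -1)
print(y.sum(axis=0))  # close to 0 along axis 0 now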

If there is something more here, let me know and we can reopen/discuss further.


thegodone commented on July 28, 2024

But if I use the weights and bias from Keras in the MLX LayerNorm I don't get the same result. Why?

If I remove the LayerNorm from both models, the assert passes.

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, LayerNormalization
import numpy as np
import mlx.nn as nn
from mlx.utils import tree_flatten
import mlx.core as mx

# Define the model
model = Sequential([
    Embedding(input_dim=40, output_dim=32, input_length=20, name="Embedding"),
    LayerNormalization(name='LN1'),

])

# For the Embedding, input_dim is the vocabulary size and output_dim is the embedding dimension.
model.compile(
    optimizer='adam',  # Optimizer
    loss='mse',  # Mean Squared Error for regression tasks
    metrics=['mae']  # Mean Absolute Error for regression metrics
)
model.summary()

def exposeweights(model):
    w = {}
    j = 0
    for layer in model.layers:
        weights = layer.get_weights()  # returns a list of all weight tensors in the layer
        print(layer.name)
        for i, weight in enumerate(weights):
            # are the Dense / linear weight arrays transposed relative to MLX (i.e. need a .T)?
            if layer.name+"."+str(i) in ['Output.0','Proj.0','TimeDistributed.0'] :
                w[layer.name+"."+str(i)]=weight.T
            else:
                w[layer.name+"."+str(i)]=weight
            j+=1
    return w

w = exposeweights(model)
np.savez('w.npz', **w)

class mlxduplicate(nn.Module):
    def __init__(
        self):
        super().__init__()
        self.Embedding = nn.Embedding(num_embeddings=40, dims=32)
        self.LN1 = nn.LayerNorm(32)

    def __call__(self, x):
        x = self.Embedding(x)
        x = self.LN1(x)
        return x 


model_mlx = mlxduplicate()
we = 0
for k, x in tree_flatten(model_mlx.parameters()):
    we+=x.size
    print(x.size,k)
print(we)

tensor_loaded = np.load('w.npz')

def replace_key(key: str) -> str:
    key = key.replace("Embedding.0", "Embedding.weight")
    key = key.replace("LN1.0", "LN1.weight")
    key = key.replace("LN1.1", "LN1.bias")
    return key

# switch layer names of saved keras tensors
tensors_mlx = {
    replace_key(key): tensor for key, tensor in tensor_loaded.items()
}

for k,v in tensor_loaded.items():
    print(k,v.shape)

for k,v in tensors_mlx.items():
    print(k,v.shape)

np.savez('w_convert_to_mlx.npz', **tensors_mlx)

model_mlx.load_weights('w_convert_to_mlx.npz')

x_train = np.random.randint(0,39, (2,20))
keras_output = model.predict(x_train)
keras_output.shape

mlx_output = model_mlx(mx.array(x_train))
assert mlx_output.shape == keras_output.shape


assert np.max(np.abs(mlx_output-mx.array(keras_output))) < 1e-6

2024-04-28 20:12:54.329876: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Max
2024-04-28 20:12:54.329898: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 128.00 GB
2024-04-28 20:12:54.329900: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 48.00 GB
2024-04-28 20:12:54.329932: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-04-28 20:12:54.329949: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 Embedding (Embedding)       (None, 20, 32)            1280      
                                                                 
 LN1 (LayerNormalization)    (None, 20, 32)            64        
                                                                 
=================================================================
Total params: 1344 (5.25 KB)
Trainable params: 1344 (5.25 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Embedding
LN1
1280 Embedding.weight
32 LN1.bias
32 LN1.weight
1344
Embedding.0 (40, 32)
LN1.0 (32,)
LN1.1 (32,)
Embedding.weight (40, 32)
LN1.weight (32,)
LN1.bias (32,)
2024-04-28 20:12:54.785211: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
1/1 [==============================] - 1s 832ms/step
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[2], line 93
     89 mlx_output = model_mlx(mx.array(x_train))
     90 assert mlx_output.shape == keras_output.shape
---> 93 assert np.max(np.abs(mlx_output-mx.array(keras_output))) < 1e-6

AssertionError:


awni commented on July 28, 2024

It looks like Keras uses a different (and much higher) default epsilon for numerical stability. You can set this in the MLX LayerNorm constructor. The following passes:

from tensorflow.keras.layers import LayerNormalization
import numpy as np
import mlx.nn as nn
import mlx.core as mx

# Define the model
ln = LayerNormalization(name='LN1')
x = np.random.uniform(size=(10, 32))
out_keras = np.array(ln(x))

ln_mlx = nn.LayerNorm(32, eps=1e-3) # note setting epsilon here
out_mlx = np.array(ln_mlx(mx.array(x)))
assert np.abs((out_keras - out_mlx)).max() < 1e-6
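
Equivalently, the mismatch can be resolved from the Keras side by lowering its epsilon to MLX's default of 1e-5 (a minimal variant of the snippet above; epsilon is a standard LayerNormalization argument):

from tensorflow.keras.layers import LayerNormalization
import numpy as np
import mlx.nn as nn
import mlx.core as mx

ln = LayerNormalization(name='LN1', epsilon=1e-5)  # match MLX's default eps
x = np.random.uniform(size=(10, 32))
out_keras = np.array(ln(x))

ln_mlx = nn.LayerNorm(32)  # default eps=1e-5
out_mlx = np.array(ln_mlx(mx.array(x)))
assert np.abs(out_keras - out_mlx).max() < 1e-6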


thegodone commented on July 28, 2024

Thanks @awni, I very much appreciate your help on that. I still have one question: can you explain this error? For very large datasets I get a strange behaviour.
[screenshot of the error]
I use this code:

import math
from typing import Any
import mlx.nn as nn
from mlx.utils import tree_flatten
import numpy as np
import mlx.core as mx

from mlx.nn.layers.base import Module

class AttentionM_(Module):
    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()
        self.output_dims = output_dims
        scale = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, 1),
        )
        if bias:
            self.bias = mx.zeros(shape=(output_dims, 1))

    def _extra_repr(self) -> str:
        return f"input_dims={self.weight.shape[0]}, output_dims={self.output_dims}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        if "bias" in self:
            x_ = mx.addmm(self["bias"], x, self["weight"])
        else:
            x_ = x @ self["weight"]
        x_ = mx.tanh(x_)
        x_ = mx.expand_dims(mx.softmax(mx.squeeze(x_,axis=-1), axis=-1),axis=-1)
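        # x_ now holds a softmax attention weight per time step, shape (B, T, 1);
        # the next line collapses the time axis with an attention-weighted sum: (B, T, D) -> (B, D)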
        x = mx.sum(x*x_,axis=1)
        return x
        
class TimeDistributed_(nn.Module):
    def __init__(
        self,
        func : nn.Module
    ):
        super().__init__()
        self.func = func

    def __call__(self, x):
        b_, t_ = x.shape[:2]
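        # merge batch and time so func sees (B*T, ...), then restore (B, T, ...) from its output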
        c_ = self.func(x.flatten(0,1))
        return c_.reshape(b_, t_, *c_.shape[1:])
        
class Bidirectionnal_(nn.Module):
    def __init__(
        self,
        func1 : nn.Module,
        func2 : nn.Module
    ):
        super().__init__()
        self.func1 = func1
        self.func2 = func2

    def __call__(self, x):

        h_f, h_b = self.func1(x), self.func2(x[:, ::-1, :]) 
        return  mx.stack([h_f[0], h_b[0]], axis=-1).flatten(-2,-1)
    


class SmilesX(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        inputdim: int,
        embdim: int,
        lstmdim: int,
        densedim1: int,
        densedim2: int,
        checkpoint: bool,
        debug: bool,
    ):
        super().__init__()

        self.Embedding = nn.Embedding(num_embeddings=vocab_size, dims=embdim)
        
        self.Image = Bidirectionnal_(nn.LSTM(embdim, lstmdim, bias=True),
                                    nn.LSTM(embdim, lstmdim, bias=True))
        self.TimeDistributed = TimeDistributed_(nn.Linear(2*lstmdim,densedim1))
        self.AttentionM  = AttentionM_(densedim1,inputdim, bias=True)
        self.Layernorm1 =  nn.LayerNorm(densedim1,eps=0.001)
        self.Proj = nn.Linear(densedim1,densedim2)
        self.Layernorm2 =  nn.LayerNorm(densedim2,eps=0.001)
        self.Output = nn.Linear(densedim2,1)
        self.lk = nn.LeakyReLU(0.1)
        self.debug = debug
    
    def __call__(self, x):
        if self.debug:
            print('Input:',x.shape)

        # embedding
        x = self.Embedding(x)
        if self.debug:
            print('Embedding:',x.shape)
        # Bidirectional 
        x = self.Image(x)
        if self.debug:
            print('BiLSTM:',x.shape)

        #  TimeDistributed 
        x = self.TimeDistributed(x)
        if self.debug:
            print('TimeDistributed:',x.shape)
       # self attention
        x = self.AttentionM(x)
        if self.debug:
            print('AttentionM:',x.shape)

        # Layer norm 
        x = self.Layernorm1(x)
        if self.debug:
            print('LayerNorm 1:',x.shape)        
        x = self.Proj(x)        
        if self.debug:
            print('proj:',x.shape)

        x = self.lk(x)
        # Layer norm 
        x = self.Layernorm2(x)
        if self.debug:
            print('LayerNorm 2:',x.shape)

        x = self.Output(x)
        if self.debug:
            print('Output:',x.shape)
        return x


model = SmilesX(vocab_size=42, 
                inputdim = 128,
                embdim = 32,
                lstmdim =  32,
                densedim1 = 64,
                densedim2 = 64,
                checkpoint=False,
                debug=False)



# Initialize model:
nparams = sum(
    x.size for k, x in tree_flatten(model.parameters()))
print(f"Training a SMILES-X Model with {nparams}  parameters")

xt = 0
for k, x in tree_flatten(model.parameters()):
    print(x.size,k)
    xt+=x.size
print(xt)

# test for big dataset:
X = mx.random.randint(0,42,[600000,128])

# evaluate by data size the results
Y1 = model(X[:10,:])
Y3 = model(X[:100,:])
Y2 = model(X[:200,:])
Y4 = model(X[:1000,:])
Y5 = model(X[:10000,:])
Y6 = model(X[:100000,:])
Y7 = model(X[:600000,:])

# validate the results 
assert np.max(np.abs(Y1 - Y3[:10]))<1e-8
assert np.max(np.abs(Y1 - Y2[:10]))<1e-8
assert np.max(np.abs(Y1 - Y4[:10]))<1e-8
assert np.max(np.abs(Y1 - Y5[:10]))<1e-8
assert np.max(np.abs(Y1 - Y6[:10]))<1e-8
assert np.max(np.abs(Y1 - Y7[:10]))<1e-8


awni commented on July 28, 2024

The size is really large; I think some matrices are well over 4B entries. My guess is it's overflowing an integer index somewhere, but I'm not sure where. I'll look into it to see if we can add an error message or fix it. For now I would stick to smaller sizes.
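
If you need results for the full 600k rows in the meantime, one interim workaround (a rough sketch, not a built-in MLX feature) is to run inference in chunks and concatenate the pieces:

import mlx.core as mx

def predict_in_batches(model, X, batch_size=4096):
    # evaluate the model on slices of X and stitch the results back together
    outs = []
    for i in range(0, X.shape[0], batch_size):
        y = model(X[i:i + batch_size])
        mx.eval(y)  # force evaluation so the lazy graph stays small
        outs.append(y)
    return mx.concatenate(outs, axis=0)

# e.g. Y7 = predict_in_batches(model, X)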


awni commented on July 28, 2024

I filed a separate issue about this #1051


thegodone commented on July 28, 2024


awni commented on July 28, 2024

Does it happen during evaluation too? Would be nice to add a batch size for inference.

The problem is the large matmul in the LSTM. So if the batch size is larger than about 131k (for the LSTM / model dimensions you provided) then it will break regardless of inference / training modes.
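
As a rough sanity check of that ~131k figure (assuming the overflow is hit once a single matmul output exceeds 2^31 entries, a guess based on the integer-index overflow filed as #1051, and that the critical matmul is the LSTM input projection over all time steps):

seq_len, lstm_dim = 128, 32                 # dimensions from the model above
gate_dim = 4 * lstm_dim                     # the LSTM input projection produces 4 gates per step
max_entries = 2 ** 31                       # assumed 32-bit index limit
print(max_entries // (seq_len * gate_dim))  # 131072, i.e. roughly the 131k mentioned above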


thegodone commented on July 28, 2024

