
Comments (3)

awni commented on September 2, 2024

So in general, when you run anything on the device, a bunch of state has to be set up first. This includes loading the GPU device itself, loading the Metal library, initializing the allocator, and possibly some other things. All of that takes time, and it usually happens the first time you evaluate an operation in MLX.

It's good practice in general to time with a warmup, like so:

import time

# warmup (myfun stands in for whatever you are benchmarking)
for _ in range(5):
  myfun()

tic = time.perf_counter()
for _ in range(10):
  myfun()
toc = time.perf_counter()
time_per_it = (toc - tic) / 10

Regarding the cache:

would the subsequent mx.eval reuse the previous results?

No, that never happens.

If yes, is mx.metal.clear_cache() the right way to avoid the reuse of cache?

The cache is a memory buffer cache: we don't reuse results from it, just data buffers. I would not clear it unless you are concerned about memory use, because clearing it will slow things down.


wangkuiyi commented on September 2, 2024

I have the same question. I made each run handle different tensors (even though mx.metal.clear_cache() is called for each run). I also added another benchmark of matmul.

import time

import mlx.core as mx


def benchmark_array_creation(N: int):
    mx.metal.clear_cache()
    tic = time.perf_counter()
    mx.eval(mx.ones((1, N)))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - tic)
    print(f"array creation takes {tpi:.3f} ms")


def benchmark_matmul(N: int):
    mx.metal.clear_cache()
    tic = time.perf_counter()
    # note: * is element-wise multiplication; use @ or mx.matmul for a true matmul
    mx.eval(mx.ones((N, N)) * mx.ones((N, N)) * N)
    toc = time.perf_counter()
    tpi = 1e3 * (toc - tic)
    print(f"matmul takes {tpi:.3f} ms")


if __name__ == "__main__":
    print(mx.default_device())
    for i in range(10):
        print(f"array creation run {i}", end=" ")
        benchmark_array_creation(1024 + i)

    for i in range(10):
        print(f"matmul run {i}", end=" ")
        benchmark_matmul(1024 + i)

The result is as follows. Since the array-creation loop runs first and pays the one-time setup cost, the first run of matmul is not much slower than the rest.

array creation run 0 array creation takes 10.809 ms
array creation run 1 array creation takes 0.424 ms
array creation run 2 array creation takes 0.843 ms
array creation run 3 array creation takes 0.301 ms
array creation run 4 array creation takes 0.418 ms
array creation run 5 array creation takes 0.272 ms
array creation run 6 array creation takes 0.262 ms
array creation run 7 array creation takes 0.258 ms
array creation run 8 array creation takes 0.269 ms
array creation run 9 array creation takes 0.258 ms
matmul run 0 matmul takes 0.947 ms
matmul run 1 matmul takes 0.798 ms
matmul run 2 matmul takes 0.764 ms
matmul run 3 matmul takes 0.430 ms
matmul run 4 matmul takes 0.727 ms
matmul run 5 matmul takes 0.384 ms
matmul run 6 matmul takes 0.703 ms
matmul run 7 matmul takes 0.418 ms
matmul run 8 matmul takes 0.766 ms
matmul run 9 matmul takes 0.399 ms

Then, I switched the order of the two loops:

if __name__ == "__main__":
    print(mx.default_device())

    for i in range(10):
        print(f"matmul run {i}", end=" ")
        benchmark_matmul(1024 + i)

    for i in range(10):
        print(f"array creation run {i}", end=" ")
        benchmark_array_creation(1024 + i)

The first run of matmul is now the slowest run of all.

matmul run 0 matmul takes 10.599 ms
matmul run 1 matmul takes 0.821 ms
matmul run 2 matmul takes 0.813 ms
matmul run 3 matmul takes 0.417 ms
matmul run 4 matmul takes 0.719 ms
matmul run 5 matmul takes 0.437 ms
matmul run 6 matmul takes 0.733 ms
matmul run 7 matmul takes 0.607 ms
matmul run 8 matmul takes 0.764 ms
matmul run 9 matmul takes 0.442 ms
array creation run 0 array creation takes 0.284 ms
array creation run 1 array creation takes 0.256 ms
array creation run 2 array creation takes 0.298 ms
array creation run 3 array creation takes 0.297 ms
array creation run 4 array creation takes 0.334 ms
array creation run 5 array creation takes 0.262 ms
array creation run 6 array creation takes 0.297 ms
array creation run 7 array creation takes 0.267 ms
array creation run 8 array creation takes 0.278 ms
array creation run 9 array creation takes 0.285 ms


cccyf commented on September 2, 2024

Thank you @awni for the reply! I modified my code as you suggested, and the output makes more sense now.

import time

import mlx.core as mx

def benchmark_array_creation(N: int):
    # warmup
    for _ in range(5):
        mx.eval(mx.ones((1, N)))
    tic = time.perf_counter()
    for _ in range(100):
        mx.eval(mx.ones((1, N)))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - tic)  # total time for 100 iterations, in ms
    print(f"array creation takes {tpi:.3f} ms")
    
if __name__ == "__main__":
    print(mx.default_device())
    for i in range(10):
        print(f"benchmark array creation, run {i}")
        benchmark_array_creation(1024)

Output

benchmark array creation, run 0
array creation takes 25.980 ms
benchmark array creation, run 1
array creation takes 23.795 ms
benchmark array creation, run 2
array creation takes 20.705 ms
benchmark array creation, run 3
array creation takes 18.523 ms
benchmark array creation, run 4
array creation takes 18.641 ms
benchmark array creation, run 5
array creation takes 18.616 ms
benchmark array creation, run 6
array creation takes 18.542 ms
benchmark array creation, run 7
array creation takes 20.475 ms
benchmark array creation, run 8
array creation takes 21.766 ms
benchmark array creation, run 9
array creation takes 21.793 ms

