jaxseq's Issues

Could you please explain the rationale behind including both input and output in the data format for a causal language model?

Hello,

I came across this repository while searching for ways to train a 10-20B-scale causal language model (similar to GPT) on a TPU v4 pod.

In the README, I noticed that the recommended data format is {'train': [{'in_text', 'out_text'}, ...], 'eval': [{'in_text', 'out_text'}, ...]}. Could you please explain the reasoning behind this format?

As far as I understand, GPT-2 or GPT-J only requires a single text per training example, rather than the input-output pairs described above. I would appreciate any clarification you can provide on this matter.
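For concreteness, here is a toy sketch of the two conventions I am comparing (the data below is made up by me, not taken from the README):

# Toy example contrasting the two data conventions discussed above.
# JAXSeq-style input/output pairs, as the README describes them:
jaxseq_style = {
    "train": [{"in_text": "Translate to French: cat", "out_text": " chat"}],
    "eval":  [{"in_text": "Translate to French: dog", "out_text": " chien"}],
}

# Plain causal-LM convention: a single text per example. For a decoder-only
# model, each pair above could simply be concatenated into one string:
single_text_style = [ex["in_text"] + ex["out_text"] for ex in jaxseq_style["train"]]
print(single_text_style)  # ['Translate to French: cat chat']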

Thank you.

OOM issue training on TPUv3-8

Hi,

I'm trying to fine-tune a gpt-j-6B model on a TPU v3-8 instance, but I get an OOM error despite using model parallelism. Could you please let me know what might be wrong and what fixes I could try?

Here's the command I used:
python examples/gptj_train.py $expname EleutherAI/gpt-j-6B /home/qj213/data/debug.json --use-wandb --wandb-project $expname --model-p-shape 8 --data-p-shape 1 --train-bsize 1 --inference-bsize 1

And here's the error I received:

Traceback (most recent call last):
File "/home/qj213/JAXSeq/examples/gptj_train.py", line 265, in
tyro.cli(main)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/tyro/_cli.py", line 114, in cli
_cli_impl(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/tyro/_cli.py", line 293, in _cli_impl
out, consumed_keywords = _calling.call_from_args(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/tyro/_calling.py", line 192, in call_from_args
return unwrapped_f(*args, **kwargs), consumed_keywords # type: ignore
File "/home/qj213/JAXSeq/examples/gptj_train.py", line 239, in main
trainer, inference = train_loop(
File "/home/qj213/JAXSeq/src/seq2seq_train.py", line 108, in train_loop
_, info, trainer = trainer.train_step(items, new_rng)
File "/home/qj213/JAXSeq/src/core.py", line 59, in train_step
loss, info, new_params, new_optim_state = self.train_fn(self.params, self.optim_state, rng_key, batch)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/experimental/pjit.py", line 138, in wrapped
return _python_pjit_helper(infer_params, *args, **kwargs)[0]
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/experimental/pjit.py", line 130, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/core.py", line 328, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/core.py", line 331, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/core.py", line 698, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/experimental/pjit.py", line 919, in _pjit_call_impl
compiled = _pjit_lower(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 3072, in compile
self._executable = MeshExecutable.from_hlo(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 3251, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/_src/dispatch.py", line 1056, in compile_or_get_cached
return backend_compile(backend, serialized_computation, compile_options,
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/_src/profiler.py", line 313, in wrapper
return func(*args, **kwargs)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/_src/dispatch.py", line 996, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 17.28G of 15.48G hbm. Exceeded hbm capacity by 1.80G.

Total hbm usage >= 17.80G:
reserved 530.00M
program 5.78G
arguments 11.51G

Output size 11.51G; shares 11.51G with arguments.

Program hbm requirement 5.78G:
global 1.35M
scoped 7.52M
HLO temp 5.77G (100.0% utilization: Unpadded (5.76G) Padded (5.76G), 0.2% fragmentation (12.13M))

Largest program allocations in hbm:

  1. Size: 128.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/lm_head/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[4096,8192]{1,0:T(8,128)}
    Unpadded size: 128.00M
    XLA label: fusion.1349 = fusion(param.521, convert.1432, fusion.116, fusion.341), kind=kOutput, calls=fused_computation.1094
    Allocation type: HLO temp

  2. Size: 128.00M
    Operator: op_name="pjit(step_fn)/jit(main)/cond/branch_0_fun/mul" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/transform.py" source_line=405
    Shape: f32[4096,8192]{1,0:T(8,128)}
    Unpadded size: 128.00M
    XLA label: fusion.489 = fusion(get-tuple-element.1378, subtract.201, subtract.200, get-tuple-element.1377, ...(+2)), kind=kLoop, calls=fused_computation.234
    Allocation type: HLO temp

  3. Size: 128.00M
    Operator: op_name="pjit(step_fn)/jit(main)/cond/branch_0_fun/mul" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/transform.py" source_line=405
    Shape: f32[8192,4096]{1,0:T(8,128)}
    Unpadded size: 128.00M
    XLA label: fusion.486 = fusion(get-tuple-element.2510, subtract.201, subtract.200, get-tuple-element.2509, ...(+2)), kind=kLoop, calls=fused_computation.231
    Allocation type: HLO temp

  4. Size: 128.00M
    Operator: op_name="pjit(step_fn)/jit(main)/add" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/wrappers.py" source_line=327
    Shape: f32[8192,4096]{1,0:T(8,128)}
    Unpadded size: 128.00M
    XLA label: fusion.1348 = fusion(param.1087, fusion.229, convert.1432), kind=kLoop, calls=fused_computation.1093
    Allocation type: HLO temp

  5. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/13/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[4096,2048]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1399 = fusion(param.637, convert.1432, fusion.281, get-tuple-element.6842), kind=kOutput, calls=fused_computation.1144
    Allocation type: HLO temp

  6. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/13/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[2048,4096]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1398 = fusion(param.641, convert.1432, fusion.401, fusion.3882), kind=kOutput, calls=fused_computation.1143
    Allocation type: HLO temp

  7. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/12/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[2048,4096]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1400 = fusion(param.621, convert.1432, fusion.405, fusion.3867), kind=kOutput, calls=fused_computation.1145
    Allocation type: HLO temp

  8. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/12/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[4096,2048]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1401 = fusion(param.617, convert.1432, fusion.277, get-tuple-element.6840), kind=kOutput, calls=fused_computation.1146
    Allocation type: HLO temp

  9. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/11/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[2048,4096]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1402 = fusion(param.601, convert.1432, fusion.409, fusion.3852), kind=kOutput, calls=fused_computation.1147
    Allocation type: HLO temp

  10. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/9/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[2048,4096]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1354 = fusion(param.1081, convert.1432, fusion.417, fusion.3822), kind=kOutput, calls=fused_computation.1099
    Allocation type: HLO temp
    ==========================

  11. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/cond/branch_0_fun/mul" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/transform.py" source_line=405
    Shape: f32[4096,2048]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.657 = fusion(get-tuple-element.1410, subtract.201, subtract.200, get-tuple-element.1409, ...(+2)), kind=kLoop, calls=fused_computation.402
    Allocation type: HLO temp
    ==========================

  12. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/9/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[4096,2048]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1355 = fusion(param.1077, convert.1432, fusion.265, get-tuple-element.6886), kind=kOutput, calls=fused_computation.1100
    Allocation type: HLO temp
    ==========================

  13. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/cond/branch_0_fun/mul" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/transform.py" source_line=405
    Shape: f32[2048,4096]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.654 = fusion(get-tuple-element.1418, subtract.201, subtract.200, get-tuple-element.1417, ...(+2)), kind=kLoop, calls=fused_computation.399
    Allocation type: HLO temp
    ==========================

  14. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/14/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[4096,2048]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1397 = fusion(param.657, convert.1432, fusion.285, get-tuple-element.6844), kind=kOutput, calls=fused_computation.1142
    Allocation type: HLO temp
    ==========================

  15. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/11/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[4096,2048]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1403 = fusion(param.597, convert.1432, fusion.273, get-tuple-element.6838), kind=kOutput, calls=fused_computation.1148
    Allocation type: HLO temp
    ==========================

  16. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/14/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[2048,4096]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1396 = fusion(param.661, convert.1432, fusion.397, fusion.3897), kind=kOutput, calls=fused_computation.1141
    Allocation type: HLO temp
    ==========================

  17. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/10/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[2048,4096]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1404 = fusion(param.581, convert.1432, fusion.413, fusion.3837), kind=kOutput, calls=fused_computation.1149
    Allocation type: HLO temp
    ==========================

  18. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/10/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[4096,2048]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1405 = fusion(param.577, convert.1432, fusion.269, get-tuple-element.6836), kind=kOutput, calls=fused_computation.1150
    Allocation type: HLO temp
    ==========================

  19. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/15/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
    Shape: f32[4096,2048]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.1395 = fusion(param.677, convert.1432, fusion.289, get-tuple-element.6846), kind=kOutput, calls=fused_computation.1140
    Allocation type: HLO temp
    ==========================

  20. Size: 32.00M
    Operator: op_name="pjit(step_fn)/jit(main)/cond/branch_0_fun/mul" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/transform.py" source_line=405
    Shape: f32[4096,2048]{1,0:T(8,128)}
    Unpadded size: 32.00M
    XLA label: fusion.651 = fusion(get-tuple-element.1450, subtract.201, subtract.200, get-tuple-element.1449, ...(+2)), kind=kLoop, calls=fused_computation.396
    Allocation type: HLO temp
    ==========================

I have previously used the mesh-transformer-jax library to fine-tune the 6B model on the same machine, so fitting it should be possible. Any help with this issue would be greatly appreciated!
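For what it is worth, here is a rough back-of-the-envelope memory estimate (my own assumptions, not an official diagnosis) that seems consistent with the numbers in the log above:

# Rough estimate of persistent HBM for full fp32 Adam fine-tuning of GPT-J.
# All numbers here are assumptions on my part, not taken from the JAXSeq docs.
n_params = 6.05e9                                # approximate GPT-J parameter count
bytes_per_value = 4                              # fp32
persistent = n_params * bytes_per_value * 4      # params + grads + 2 Adam moments
per_core = persistent / 8                        # sharded with --model-p-shape 8
print(f"total ~{persistent / 2**30:.0f} GiB, per core ~{per_core / 2**30:.1f} GiB")
# -> total ~90 GiB, per core ~11.3 GiB, which roughly matches the
#    "arguments 11.51G" line above and leaves little headroom in ~16 GiB of HBM
#    once the ~5.8 GiB of program temporaries are added.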

Aqt?

Hey, do you have any experience with this library (AQT)?
I'm thinking about doing quantized training, and it's not clear to me how to use it.

gptj serve issue on a TPU-VM

Hello, I'm using the library on a TPU VM (v3-8). The runtime software version is tpu-vm-base.

Running the command

python gptj_serve.py

I got the following error:

Traceback (most recent call last):
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 611, in connect
    sock = self.retry.call_with_retry(
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/retry.py", line 46, in call_with_retry
    return do()
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 612, in <lambda>
    lambda: self._connect(), lambda error: self.disconnect(error)
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 677, in _connect
    raise err
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 665, in _connect
    sock.connect(socket_address)
ConnectionRefusedError: [Errno 111] Connection refused

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/qj213/JAXSeq/examples/gptj_serve.py", line 98, in <module>
    inference_server = InferenceServerMP(
  File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 21, in __init__
    self.Q = initalize_server(self, super().__getattribute__('r'), cache_cls, args, kwargs)
  File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 79, in initalize_server
    build_method(Config.init_message, r, Q)(self)
  File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 44, in call_method
    request_id = int(r.incr('request_id_counter'))
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/commands/core.py", line 1831, in incrby
    return self.execute_command("INCRBY", name, amount)
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/client.py", line 1235, in execute_command
    conn = self.connection or pool.get_connection(command_name, **options)
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 1387, in get_connection
    connection.connect()
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 617, in connect
    raise ConnectionError(self._error_message(e))
redis.exceptions.ConnectionError: Error 111 connecting to localhost:6379. Connection refused.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/qj213/JAXSeq/examples/gptj_serve.py", line 98, in <module>
    inference_server = InferenceServerMP(
  File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 21, in __init__
    self.Q = initalize_server(self, super().__getattribute__('r'), cache_cls, args, kwargs)
  File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 79, in initalize_server
    build_method(Config.init_message, r, Q)(self)
  File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 44, in call_method
    request_id = int(r.incr('request_id_counter'))
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/commands/core.py", line 1831, in incrby
    return self.execute_command("INCRBY", name, amount)
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/client.py", line 1235, in execute_command
    conn = self.connection or pool.get_connection(command_name, **options)
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 1387, in get_connection
    connection.connect()
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 617, in connect
    raise ConnectionError(self._error_message(e))
redis.exceptions.ConnectionError: Error 111 connecting to localhost:6379. Connection refused.
using mesh shape: (1, 8)
full mesh: [[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)
  TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)
  TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0)
  TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)
  TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0)
  TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1)
  TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0)
  TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]]
current process index 0, in position [0, 0] of [1, 1]
tcmalloc: large alloc 24203542528 bytes == 0x560314c42000 @  0x7f615c34c680 0x7f615c36d824 0x560289f3e53b 0x560289f7f0ba 0x56028a055a58 0x560289fb148d 0x560289e8b328 0x56028a06b66d 0x560289fb1825 0x560289f0f2da 0x560289fa6fe3 0x560289fa8709 0x560289f0e73d 0x560289fa7be4 0x560289f0e088 0x560289fa6fe3 0x560289fa7d24 0x560289f0e73d 0x560289fa6fe3 0x560289fa7d24 0x560289f92a2e 0x560289f9c429 0x560289f676ab 0x560289f56359 0x560289fe7e7a 0x560289fa7be4 0x560289f5630a 0x560289fe7e7a 0x560289fa7be4 0x560289f0f2da 0x560289fa6fe3
unmatches keys: set()
Process Process-2:
Traceback (most recent call last):
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/managers.py", line 802, in _callmethod
    conn = self._tls.connection
AttributeError: 'ForkAwareLocal' object has no attribute 'connection'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 60, in server_process
    request_id, method, args, kwargs = Q.get()
  File "<string>", line 2, in get
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/managers.py", line 806, in _callmethod
    self._connect()
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/managers.py", line 793, in _connect
    conn = self._Client(self._token.address, authkey=self._authkey)
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/connection.py", line 507, in Client
    c = SocketClient(address)
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/connection.py", line 635, in SocketClient
    s.connect(address)
ConnectionRefusedError: [Errno 111] Connection refused

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 73, in server_process
    raise Exception
Exception

Any idea how I should fix this? Thank you!
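In case it is useful to others, here is a minimal check one could run before launching gptj_serve.py (my own sketch, on the assumption that the error simply means no Redis server is listening on localhost:6379):

# Quick connectivity check using redis-py, which the traceback shows is
# already installed. This is only a diagnostic sketch, not part of JAXSeq.
import redis

r = redis.Redis(host="localhost", port=6379)
try:
    r.ping()
    print("Redis is reachable")
except redis.exceptions.ConnectionError:
    print("No Redis server on localhost:6379 -- one needs to be started "
          "(e.g. via `redis-server`) before running the serve script")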

Details about the partitioning rules

Hi,
great work!

I have been looking into the set of partition rules that you use to parallelize the models.

Would you mind explaining the idea behind these specific partitioning schemes? Is there a particular heuristic that you are following? For example, I have noticed that you often swap the "mp" and "fsdp" axes for consecutive layers. Is this to minimize communication costs, or is there another reason?

Thanks a lot!
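To make the question concrete, here is a purely illustrative pair of pseudo-rules (my own toy example, not the actual JAXSeq rules) showing the kind of axis swap I mean:

# Illustrative only: two consecutive matmul kernels where the "fsdp" and "mp"
# mesh axes appear in opposite orders. Written with current JAX's PartitionSpec.
from jax.sharding import PartitionSpec as P

illustrative_rules = [
    (("mlp", "fc_in", "kernel"), P("fsdp", "mp")),   # rows on fsdp, cols on mp
    (("mlp", "fc_out", "kernel"), P("mp", "fsdp")),  # axes swapped on the next layer
]

for path, spec in illustrative_rules:
    print("/".join(path), "->", spec)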
