sea-snell / jaxseq
Train very large language models in Jax.
License: MIT License
Or, if there are any other related resources you could point me to!
Hello,
I came across this repository while searching for ways to train a 10~20B-scale causal language model (similar to GPT) using a TPUv4 pod.
In the README, I noticed that the recommended data format is {'train': [{'in_text', 'out_text'}, ...], 'eval': [{'in_text', 'out_text'}, ...]}. Could you please explain the reasoning behind this format?
As far as I understand, GPT-2 or GPT-J only requires single texts as input and output for training, instead of the input-output pairs mentioned in the data format. I appreciate any clarification you can provide on this matter.
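For context, here is how I imagined mapping a plain single-text LM corpus into that paired format (just my guess, with `in_text` left empty so the full sequence sits in `out_text`):

```python
import json

# My guess at adapting a single-text LM corpus to the README's
# {'train': [{'in_text', 'out_text'}, ...]} format: leave in_text empty
# so out_text carries the whole training sequence.
texts = ["The quick brown fox.", "Jumps over the lazy dog."]
data = {
    "train": [{"in_text": "", "out_text": t} for t in texts],
    "eval":  [{"in_text": "", "out_text": texts[0]}],
}
with open("debug.json", "w") as f:
    json.dump(data, f)
```

Is that the intended usage, or does the pair format exist for seq2seq-style loss masking on `in_text`?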
Thank you.
Hi,
I'm trying to fine-tune a gpt-j-6B model on a TPU v3-8 instance. I get an OOM error despite using model parallelism. Could you please let me know what could be wrong and what possible fixes I could try?
Here's the command I used:
python examples/gptj_train.py $expname EleutherAI/gpt-j-6B /home/qj213/data/debug.json --use-wandb --wandb-project $expname --model-p-shape 8 --data-p-shape 1 --train-bsize 1 --inference-bsize 1
And here's the error I received:
Traceback (most recent call last):
File "/home/qj213/JAXSeq/examples/gptj_train.py", line 265, in
tyro.cli(main)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/tyro/_cli.py", line 114, in cli
_cli_impl(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/tyro/_cli.py", line 293, in _cli_impl
out, consumed_keywords = _calling.call_from_args(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/tyro/_calling.py", line 192, in call_from_args
return unwrapped_f(*args, **kwargs), consumed_keywords # type: ignore
File "/home/qj213/JAXSeq/examples/gptj_train.py", line 239, in main
trainer, inference = train_loop(
File "/home/qj213/JAXSeq/src/seq2seq_train.py", line 108, in train_loop
_, info, trainer = trainer.train_step(items, new_rng)
File "/home/qj213/JAXSeq/src/core.py", line 59, in train_step
loss, info, new_params, new_optim_state = self.train_fn(self.params, self.optim_state, rng_key, batch)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/experimental/pjit.py", line 138, in wrapped
return _python_pjit_helper(infer_params, *args, **kwargs)[0]
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/experimental/pjit.py", line 130, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/core.py", line 328, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/core.py", line 331, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/core.py", line 698, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/experimental/pjit.py", line 919, in _pjit_call_impl
compiled = _pjit_lower(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 3072, in compile
self._executable = MeshExecutable.from_hlo(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 3251, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/_src/dispatch.py", line 1056, in compile_or_get_cached
return backend_compile(backend, serialized_computation, compile_options,
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/_src/profiler.py", line 313, in wrapper
return func(*args, **kwargs)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/jax/_src/dispatch.py", line 996, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 17.28G of 15.48G hbm. Exceeded hbm capacity by 1.80G.
Total hbm usage >= 17.80G:
reserved 530.00M
program 5.78G
arguments 11.51G
Output size 11.51G; shares 11.51G with arguments.
Program hbm requirement 5.78G:
global 1.35M
scoped 7.52M
HLO temp 5.77G (100.0% utilization: Unpadded (5.76G) Padded (5.76G), 0.2% fragmentation (12.13M))
Largest program allocations in hbm:
Size: 128.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/lm_head/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[4096,8192]{1,0:T(8,128)}
Unpadded size: 128.00M
XLA label: fusion.1349 = fusion(param.521, convert.1432, fusion.116, fusion.341), kind=kOutput, calls=fused_computation.1094
Allocation type: HLO temp
==========================
Size: 128.00M
Operator: op_name="pjit(step_fn)/jit(main)/cond/branch_0_fun/mul" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/transform.py" source_line=405
Shape: f32[4096,8192]{1,0:T(8,128)}
Unpadded size: 128.00M
XLA label: fusion.489 = fusion(get-tuple-element.1378, subtract.201, subtract.200, get-tuple-element.1377, ...(+2)), kind=kLoop, calls=fused_computation.234
Allocation type: HLO temp
==========================
Size: 128.00M
Operator: op_name="pjit(step_fn)/jit(main)/cond/branch_0_fun/mul" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/transform.py" source_line=405
Shape: f32[8192,4096]{1,0:T(8,128)}
Unpadded size: 128.00M
XLA label: fusion.486 = fusion(get-tuple-element.2510, subtract.201, subtract.200, get-tuple-element.2509, ...(+2)), kind=kLoop, calls=fused_computation.231
Allocation type: HLO temp
==========================
Size: 128.00M
Operator: op_name="pjit(step_fn)/jit(main)/add" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/wrappers.py" source_line=327
Shape: f32[8192,4096]{1,0:T(8,128)}
Unpadded size: 128.00M
XLA label: fusion.1348 = fusion(param.1087, fusion.229, convert.1432), kind=kLoop, calls=fused_computation.1093
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/13/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[4096,2048]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1399 = fusion(param.637, convert.1432, fusion.281, get-tuple-element.6842), kind=kOutput, calls=fused_computation.1144
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/13/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[2048,4096]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1398 = fusion(param.641, convert.1432, fusion.401, fusion.3882), kind=kOutput, calls=fused_computation.1143
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/12/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[2048,4096]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1400 = fusion(param.621, convert.1432, fusion.405, fusion.3867), kind=kOutput, calls=fused_computation.1145
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/12/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[4096,2048]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1401 = fusion(param.617, convert.1432, fusion.277, get-tuple-element.6840), kind=kOutput, calls=fused_computation.1146
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/11/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[2048,4096]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1402 = fusion(param.601, convert.1432, fusion.409, fusion.3852), kind=kOutput, calls=fused_computation.1147
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/9/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[2048,4096]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1354 = fusion(param.1081, convert.1432, fusion.417, fusion.3822), kind=kOutput, calls=fused_computation.1099
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/cond/branch_0_fun/mul" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/transform.py" source_line=405
Shape: f32[4096,2048]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.657 = fusion(get-tuple-element.1410, subtract.201, subtract.200, get-tuple-element.1409, ...(+2)), kind=kLoop, calls=fused_computation.402
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/9/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[4096,2048]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1355 = fusion(param.1077, convert.1432, fusion.265, get-tuple-element.6886), kind=kOutput, calls=fused_computation.1100
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/cond/branch_0_fun/mul" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/transform.py" source_line=405
Shape: f32[2048,4096]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.654 = fusion(get-tuple-element.1418, subtract.201, subtract.200, get-tuple-element.1417, ...(+2)), kind=kLoop, calls=fused_computation.399
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/14/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[4096,2048]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1397 = fusion(param.657, convert.1432, fusion.285, get-tuple-element.6844), kind=kOutput, calls=fused_computation.1142
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/11/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[4096,2048]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1403 = fusion(param.597, convert.1432, fusion.273, get-tuple-element.6838), kind=kOutput, calls=fused_computation.1148
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/14/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[2048,4096]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1396 = fusion(param.661, convert.1432, fusion.397, fusion.3897), kind=kOutput, calls=fused_computation.1141
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/10/mlp/fc_out/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[2048,4096]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1404 = fusion(param.581, convert.1432, fusion.413, fusion.3837), kind=kOutput, calls=fused_computation.1149
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/10/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[4096,2048]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1405 = fusion(param.577, convert.1432, fusion.269, get-tuple-element.6836), kind=kOutput, calls=fused_computation.1150
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/transpose(jvp(FlaxGPTJForCausalLMModule))/transformer/h/15/mlp/fc_in/transpose[permutation=(1, 0)]" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/flax/linen/linear.py" source_line=196
Shape: f32[4096,2048]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.1395 = fusion(param.677, convert.1432, fusion.289, get-tuple-element.6846), kind=kOutput, calls=fused_computation.1140
Allocation type: HLO temp
==========================
Size: 32.00M
Operator: op_name="pjit(step_fn)/jit(main)/cond/branch_0_fun/mul" source_file="/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/optax/_src/transform.py" source_line=405
Shape: f32[4096,2048]{1,0:T(8,128)}
Unpadded size: 32.00M
XLA label: fusion.651 = fusion(get-tuple-element.1450, subtract.201, subtract.200, get-tuple-element.1449, ...(+2)), kind=kLoop, calls=fused_computation.396
Allocation type: HLO temp
==========================
I have used the mesh-transformer-jax library before to fine-tune the 6B model on the same machine, so it should be possible. If you could help with this issue, that would be truly appreciated!
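For anyone hitting the same wall, a back-of-envelope estimate (my own arithmetic, assuming full fp32 Adam state sharded evenly over the 8 cores) suggests a v3-8 is right at the edge for this setup:

```python
# Rough HBM estimate for fine-tuning GPT-J-6B with Adam in fp32.
# Assumptions (mine, not from the repo): params, grads, and both Adam
# moments kept in fp32, sharded evenly across the 8 cores of a v3-8
# (about 16 GiB HBM each).
n_params = 6.05e9                  # GPT-J-6B parameter count
weights = n_params * 4             # fp32 bytes for the weights alone
total = weights * 4                # weights + grads + Adam m and v
per_core_gib = total / 8 / 2**30
print(f"~{per_core_gib:.1f} GiB per core before activations and XLA temps")
```

That is roughly the 11.51G "arguments" figure in the dump, which leaves too little of the 15.48G HBM for the 5.78G program, so reducing optimizer-state precision or sharding more aggressively seems necessary.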
Hey, do you have any experience with this library?
I'm thinking about doing quantized training, and it's not clear how to use it.
Hello, I'm using the library on a TPU VM (v3-8); the software version is tpu-vm-base.
Running the command
python gptj_serve.py
I got the following error:
Traceback (most recent call last):
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 611, in connect
sock = self.retry.call_with_retry(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/retry.py", line 46, in call_with_retry
return do()
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 612, in <lambda>
lambda: self._connect(), lambda error: self.disconnect(error)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 677, in _connect
raise err
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 665, in _connect
sock.connect(socket_address)
ConnectionRefusedError: [Errno 111] Connection refused
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/qj213/JAXSeq/examples/gptj_serve.py", line 98, in <module>
inference_server = InferenceServerMP(
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 21, in __init__
self.Q = initalize_server(self, super().__getattribute__('r'), cache_cls, args, kwargs)
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 79, in initalize_server
build_method(Config.init_message, r, Q)(self)
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 44, in call_method
request_id = int(r.incr('request_id_counter'))
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/commands/core.py", line 1831, in incrby
return self.execute_command("INCRBY", name, amount)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/client.py", line 1235, in execute_command
conn = self.connection or pool.get_connection(command_name, **options)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 1387, in get_connection
connection.connect()
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 617, in connect
raise ConnectionError(self._error_message(e))
redis.exceptions.ConnectionError: Error 111 connecting to localhost:6379. Connection refused.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/qj213/JAXSeq/examples/gptj_serve.py", line 98, in <module>
inference_server = InferenceServerMP(
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 21, in __init__
self.Q = initalize_server(self, super().__getattribute__('r'), cache_cls, args, kwargs)
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 79, in initalize_server
build_method(Config.init_message, r, Q)(self)
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 44, in call_method
request_id = int(r.incr('request_id_counter'))
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/commands/core.py", line 1831, in incrby
return self.execute_command("INCRBY", name, amount)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/client.py", line 1235, in execute_command
conn = self.connection or pool.get_connection(command_name, **options)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 1387, in get_connection
connection.connect()
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 617, in connect
raise ConnectionError(self._error_message(e))
redis.exceptions.ConnectionError: Error 111 connecting to localhost:6379. Connection refused.
using mesh shape: (1, 8)
full mesh: [[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)
TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0)
TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)
TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0)
TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1)
TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0)
TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]]
current process index 0, in position [0, 0] of [1, 1]
tcmalloc: large alloc 24203542528 bytes == 0x560314c42000 @ 0x7f615c34c680 0x7f615c36d824 0x560289f3e53b 0x560289f7f0ba 0x56028a055a58 0x560289fb148d 0x560289e8b328 0x56028a06b66d 0x560289fb1825 0x560289f0f2da 0x560289fa6fe3 0x560289fa8709 0x560289f0e73d 0x560289fa7be4 0x560289f0e088 0x560289fa6fe3 0x560289fa7d24 0x560289f0e73d 0x560289fa6fe3 0x560289fa7d24 0x560289f92a2e 0x560289f9c429 0x560289f676ab 0x560289f56359 0x560289fe7e7a 0x560289fa7be4 0x560289f5630a 0x560289fe7e7a 0x560289fa7be4 0x560289f0f2da 0x560289fa6fe3
unmatches keys: set()
Process Process-2:
Traceback (most recent call last):
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/managers.py", line 802, in _callmethod
conn = self._tls.connection
AttributeError: 'ForkAwareLocal' object has no attribute 'connection'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 60, in server_process
request_id, method, args, kwargs = Q.get()
File "<string>", line 2, in get
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/managers.py", line 806, in _callmethod
self._connect()
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/managers.py", line 793, in _connect
conn = self._Client(self._token.address, authkey=self._authkey)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/connection.py", line 507, in Client
c = SocketClient(address)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/connection.py", line 635, in SocketClient
s.connect(address)
ConnectionRefusedError: [Errno 111] Connection refused
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 73, in server_process
raise Exception
Exception
Any idea how I should fix this? Thank you!
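In case it helps others: the first traceback is Redis refusing connections on localhost:6379, which suggests the serving script expects a running Redis server (my reading of the trace, not official docs). A quick probe before launching:

```python
import socket

# Probe the default Redis port before launching gptj_serve.py, so a
# missing Redis server fails fast with a readable message.
# (Assumption: the serve script talks to Redis on localhost:6379.)
def redis_reachable(host: str = "localhost", port: int = 6379,
                    timeout: float = 1.0) -> bool:
    """Return True if something accepts TCP connections at host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

if not redis_reachable():
    print("Redis is not running on localhost:6379; start it "
          "(e.g. `redis-server --daemonize yes`) before serving.")
```

On a stock tpu-vm-base image, Redis likely isn't installed, so `sudo apt-get install redis-server` first.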
Is there any way to convert a Jax checkpoint to the Hugging Face format?
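One route (not JAXSeq-specific): the transformers library can load Flax weights into a PyTorch model via `from_pretrained(..., from_flax=True)`, so saving the trained params with `FlaxGPTJForCausalLM.save_pretrained` and reloading on the PyTorch side should work. The core of any manual route is flattening the nested Flax param tree into Hugging Face's dotted keys; a dependency-free sketch (the toy tree and names are illustrative, not the real GPT-J layout):

```python
import numpy as np

# Minimal sketch of flattening a nested Flax-style param tree into the
# dotted key layout Hugging Face state dicts use. The toy tree below is
# a stand-in, not the real GPT-J parameter structure.
def flatten_params(tree, prefix=""):
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_params(value, name))
        else:
            flat[name] = np.asarray(value)  # torch.from_numpy(...) on the PyTorch side
    return flat

params = {"transformer": {"h": {"0": {"mlp": {"fc_in": {"kernel": np.zeros((4, 16))}}}}}}
flat = flatten_params(params)
# Note: Flax stores Dense kernels as [in, out] while PyTorch nn.Linear
# expects [out, in], so kernels generally need a transpose when converting.
```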
Hi,
great work!
I have been looking into the set of partition rules that you use to parallelize the models.
Would you mind explaining the idea behind these specific partitioning schemes? Is there a particular heuristic you are following? For example, I have noticed that you often invert the "mp" and "fsdp" axes for consecutive layers. Is this to minimize communication costs, or is there another reason?
Thanks a lot!
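To make sure I'm reading the pattern right, here is the alternation I mean, written as plain tuples (rule names and axis labels are my assumptions, not the repo's exact table). It looks like Megatron-style sharding, where fc_in shards its output dim on "mp" so each core holds a slice of the 4*hidden activations, and fc_out shards its input dim on "mp" to consume that slice directly, leaving only an all-reduce on the block output:

```python
# Illustrative partition specs for a transformer MLP block, as
# (parameter-name pattern, (axis for dim 0, axis for dim 1)) pairs.
# Names and axis labels are assumptions for discussion, not JAXSeq's
# actual rule table.
mlp_rules = [
    ("mlp/fc_in/kernel",  ("fsdp", "mp")),  # [hidden, 4*hidden]: output dim on mp
    ("mlp/fc_out/kernel", ("mp", "fsdp")),  # [4*hidden, hidden]: input dim on mp
]
```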