Comments (6)
My latest thinking is that we don't need any new args to the Stage() ctor, as long as we continue to assume no skip connections.
If we want to enable skip connections later, we could add the args I proposed in the RFC with one adjustment: instead of 'args_rank' it would be 'args_stage', so it would tell you e.g. 'send arg 0 to stage 3, send arg 1 to stage 4', but the Stage would not yet know which rank owns those stages.
The change I would propose for now: when you ask the stage for get_*_send_ops, it takes an optional stage-mapping argument. If None, it assumes a linear modulo-pp_size mapping; if not None, it uses the mapping to determine which pp_rank a stage is on.
Think this through and see if it makes sense; it's just off the top of my head, so it could have issues.
from pytorch.
One annoying thing is that which stage counts as the 'next stage' depends on the schedule in use. It would be ideal if we could late-bind that information when the stages and schedule are used together.
Maybe the stage can have a stage-id to rank map given to it by the schedule, either during schedule init or during each call to get_*_ops?
cc @H-Huang
Yeah I agree, there will need to be a stage-id to rank mapping for the correct comm ops. There are currently a few assumptions baked into the code that need to be updated:
Assumption 1) The stage-id to rank mapping in looped cases is always `stage_ids = range(rank, total_num_stages, local_num_stages)`. We can fix this by adding a stage-id to rank mapping.
Assumption 2) You always receive from `stage_id - 1` and send to `stage_id + 1`. We can fix this with the optional arguments mentioned above.
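The two baked-in defaults could be sketched like this; the function names are hypothetical, and this just mirrors the assumptions as stated above rather than the actual PyTorch code:

```python
def default_stage_ids(rank, total_num_stages, local_num_stages):
    # Assumption 1: looped stage placement, as currently baked in.
    return list(range(rank, total_num_stages, local_num_stages))

def default_neighbors(stage_id, total_num_stages):
    # Assumption 2: recv from stage_id - 1, send to stage_id + 1
    # (None at the ends of the pipeline).
    prev_stage = stage_id - 1 if stage_id > 0 else None
    next_stage = stage_id + 1 if stage_id < total_num_stages - 1 else None
    return prev_stage, next_stage
```

A schedule-provided mapping would replace both defaults rather than hard-coding them in the Stage.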
@H-Huang one more design consideration is how we should deal with the communication between the two stages at the bottom of the 'V' that are on the same physical rank.
e.g. say stage 3 needs to send outputs to 4 and recv grads from 4.
- can we use NCCL for this use case today, until we decide to optimize it? (Does NCCL support a rank sending/recving to itself?)
- if we want to avoid doing a comm op, how can we cleanly let stage3 know about stage4 and share a tensor? perhaps the schedule code itself needs to do this by passing the output tensor from 3 as an input to 4? (and skip generating send/recv ops).
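A toy sketch of the second option, where the schedule itself short-circuits the comm for co-located stages; `forward_hop` and `rank_of` are made-up names, and tuples stand in for real P2P ops:

```python
def forward_hop(send_stage_id, recv_stage_id, output, rank_of):
    # rank_of: stage-id -> physical rank mapping (assumed given).
    if rank_of[send_stage_id] == rank_of[recv_stage_id]:
        # Same physical rank (bottom of the 'V'): hand the output
        # tensor over directly and generate no send/recv ops.
        return output, []
    # Different ranks: a real implementation would build isend/irecv
    # P2P ops here; a tuple is a placeholder.
    return None, [("isend", send_stage_id, recv_stage_id)]
```

The schedule would feed the returned tensor straight in as the next stage's input whenever the op list comes back empty.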
@wconstab I'm not sure about (1); I can test it out. But I think a clean way of doing it is to just check a condition in `get_*_send_ops` for whether the rank you are sending to is yourself; if so, then just automatically update the respective recv buffers (much like what a `get_*_recv_op` would update). I think all of the changes can remain in the `Stage` class (the stage would just somehow need to know the other stages) without any changes to the Schedule implementation. The send/recv ops will just return empty lists in this case (thus `batch_isend_irecv` will be a no-op).
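Roughly what that condition could look like as a standalone sketch; the argument names are hypothetical and tuples stand in for real `dist.P2POp` objects:

```python
def get_fwd_send_ops(self_rank, dest_rank, output,
                     peer_recv_buffers, dest_stage_id):
    if dest_rank == self_rank:
        # Local hand-off: write into the destination stage's recv
        # buffer directly and emit no comm ops; the caller can then
        # skip the batched P2P launch entirely for an empty list.
        peer_recv_buffers[dest_stage_id] = output
        return []
    # Remote destination: placeholder for building a real isend op.
    return [("isend", output, dest_rank)]
```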
> the stage would just somehow need to know the other stages
what's your proposal for how to let stages know about other stages?
- during Stage init we could not easily pass all other stages, so let's rule this out
- (a) a new method on Stage to 'register peer stages' could be called by the schedule at init time, for all stages on the same rank
- (b) passing the recv 'Stage' object to Stage.get_fwd_send_ops(recv_stage) might be another way, during schedule step()
I guess (a) is pretty clean if we can do it in a schedule base class. And we should define the fallback too. If this registration is not performed, what will happen?
- ranks will fall back to using nccl to send/recv to local same rank?
- or will this error?
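Option (a) might look roughly like this; the class and method names are hypothetical, and here the no-registration fallback is an explicit error rather than a NCCL self-send:

```python
class Stage:
    def __init__(self, stage_id, rank):
        self.stage_id = stage_id
        self.rank = rank
        self.peers = {}  # stage_id -> Stage, for stages on this rank

    def register_peer_stages(self, stages):
        # Called by the schedule base class at init time, passing all
        # stages; only those sharing this rank are retained.
        for s in stages:
            if s.rank == self.rank and s.stage_id != self.stage_id:
                self.peers[s.stage_id] = s

    def local_peer(self, stage_id):
        # Fallback question from the thread: error, or fall back to
        # NCCL send/recv to self? This sketch surfaces an error.
        if stage_id not in self.peers:
            raise RuntimeError(
                f"stage {stage_id} not registered as a local peer; "
                "was register_peer_stages() called by the schedule?")
        return self.peers[stage_id]
```

An erroring fallback makes the registration contract explicit; silently falling back to NCCL self-send would hide a mis-wired schedule.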