gpzlx1/graph_sampling's Issues
[BUG] DeviceCacheAlloc is not safe.
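inline void* DeviceCacheAlloc(size_t temp_storage_bytes) {
  c10::Allocator* cuda_allocator = c10::cuda::CUDACachingAllocator::get();
  c10::DataPtr _temp_data = cuda_allocator->allocate(temp_storage_bytes);
  // _temp_data owns the block and is destroyed on return, so the caller
  // receives a pointer to memory that has already been released.
  return _temp_data.get();
}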
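DeviceCacheAlloc is not safe. The memory it returns is no longer reserved by PyTorch's memory management, so PyTorch may overwrite it at some point even while it is still in use. The reason is that _temp_data and the function's return value have inconsistent lifetimes: _temp_data is a c10::DataPtr that owns the block and is destroyed when DeviceCacheAlloc returns, which releases the block back to the CUDA caching allocator, while the caller keeps using the raw pointer.
Use the following code in place of DeviceCacheAlloc, so that _temp_data stays alive for as long as ptr is used: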
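// Allocate in the caller's scope: _temp_data owns the block, so the memory
// stays reserved until _temp_data goes out of scope.
c10::Allocator* cuda_allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr _temp_data = cuda_allocator->allocate(temp_storage_bytes);
void* ptr = _temp_data.get();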
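The lifetime problem can be shown without a GPU or PyTorch. The following is a standalone C++ analogy, not gs or PyTorch code: it assumes c10::DataPtr behaves like a std::unique_ptr with a custom deleter (releasing the block to the caching allocator when destroyed), and the helpers toy_allocate, toy_free, UnsafeAlloc and SafeAlloc are hypothetical. UnsafeAlloc has the same shape as DeviceCacheAlloc; SafeAlloc returns the owner instead of a raw pointer.

```cpp
#include <cstddef>
#include <cstdlib>
#include <memory>

// Hypothetical stand-in for the caching allocator: counts released blocks.
static int g_released = 0;

void toy_free(void* p) {
  ++g_released;  // with PyTorch, this is where the block returns to the cache
  std::free(p);
}

// Stand-in for c10::DataPtr (assumption): owns the block, releases it on destruction.
using OwnedPtr = std::unique_ptr<void, void (*)(void*)>;

OwnedPtr toy_allocate(std::size_t bytes) {
  return OwnedPtr(std::malloc(bytes), &toy_free);
}

// Same shape as DeviceCacheAlloc: the owner is destroyed at the closing brace,
// so the returned pointer refers to memory that has already been released.
void* UnsafeAlloc(std::size_t bytes) {
  OwnedPtr owner = toy_allocate(bytes);
  return owner.get();
}

// Safer shape: hand ownership to the caller, so the block stays reserved
// until the caller's OwnedPtr goes out of scope.
OwnedPtr SafeAlloc(std::size_t bytes) {
  return toy_allocate(bytes);
}
```

Calling UnsafeAlloc leaves the counter at 1 before the caller has touched the memory, while a block from SafeAlloc is only released when the caller's OwnedPtr goes out of scope. Returning the owner (for the real allocator, the c10::DataPtr itself) and allocating directly in the caller both tie the block's lifetime to its use.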
[BUG] Backward of SDDMM (u_mul_v) produces nan values
Backward passes for SpMM and SDDMM are supported in the dev_spmm branch.
However, in the PASS sampler (examples/pass), the backward of gs.ops.u_mul_v(subA, u_feats @ W_2, v_feats @ W_2), i.e. dX = gspmm(_gidx, "mul", "sum", Y, dZ, rev_format), produces nan values even though its inputs contain no nan values.
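For comparison while debugging, the sketch below computes what this gradient should be on a tiny graph. It is a plain CPU reference, not gs code, and it assumes u_mul_v computes Z[e] = X[src[e]] * Y[dst[e]] elementwise over a COO edge list (the names src, dst, X, Y, dZ, UMulVBackwardX and AllFinite are illustrative). Under that assumption the gradient with respect to X is dX[u] = sum over edges e with src[e] == u of dZ[e] * Y[dst[e]], which matches the gspmm(..., "mul", "sum", Y, dZ, rev_format) call above. dX is zero-initialised, so a node with no outgoing edges gets a zero gradient.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major dense features: num_rows x dim.
using Mat = std::vector<float>;

// Reference backward of Z[e] = X[src[e]] * Y[dst[e]] (elementwise over dim columns)
// with respect to X: dX[u] = sum over edges e with src[e] == u of dZ[e] * Y[dst[e]].
Mat UMulVBackwardX(const std::vector<int>& src, const std::vector<int>& dst,
                   const Mat& Y, const Mat& dZ, std::size_t num_src, std::size_t dim) {
  Mat dX(num_src * dim, 0.0f);  // zero-init: nodes without edges get 0, not garbage
  for (std::size_t e = 0; e < src.size(); ++e) {
    for (std::size_t k = 0; k < dim; ++k) {
      dX[src[e] * dim + k] += dZ[e * dim + k] * Y[dst[e] * dim + k];
    }
  }
  return dX;
}

// True if every entry is finite (no nan or inf).
bool AllFinite(const Mat& m) {
  for (float x : m) {
    if (!std::isfinite(x)) return false;
  }
  return true;
}
```

Running the kernel and this reference on the same small inputs and comparing element by element (and checking AllFinite on the kernel's output) separates a wrong-formula problem from one where the kernel leaves some entries unwritten or reads out of range.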
To reproduce:
$ git checkout origin/dev_spmm
# build and install the project
$ cd examples/pass
$ python train_minibatch.py
Namespace(device='cuda', use_uva=None, dataset='reddit', batchsize=512, samples='10,10', num_workers=0)
Graph(num_nodes=232965, num_edges=114848857,
ndata_schemes={}
edata_schemes={})
WARNING: Logging before InitGoogleLogging() is written to STDERR
I20230207 03:35:00.528640 18177 graph.cc:19] Loaded CSC with 232965 nodes and 114848857 edges
Check load successfully: [None, None, tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0'), tensor([ 0, 2205, 2360, ..., 114848225, 114848365,
114848857], device='cuda:0'), tensor([225202, 177307, 107546, ..., 232594, 232634, 232964], device='cuda:0')]
memory allocated before training: 2.2396583557128906 GB
0%| | 0/300 [00:00<?, ?it/s]
/home/ubuntu/aws_projects/graph_sampling/examples/pass/train_minibatch.py:148: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
with torch.autograd.detect_anomaly():
/home/ubuntu/aws_projects/graph_sampling/examples/pass/train_minibatch.py:156: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
with torch.autograd.detect_anomaly():
/home/ubuntu/anaconda3/envs/dgl/lib/python3.9/site-packages/torch/autograd/__init__.py:173: UserWarning: Error detected in GSDDMMBackward. Traceback of forward call that caused the error:
File "/home/ubuntu/aws_projects/graph_sampling/examples/pass/train_minibatch.py", line 247, in <module>
train(dataset, args)
File "/home/ubuntu/aws_projects/graph_sampling/examples/pass/train_minibatch.py", line 125, in train
input_nodes, output_nodes, blocks, loss_tuple = compiled_func(
File "/home/ubuntu/aws_projects/graph_sampling/examples/pass/train_minibatch.py", line 33, in matrix_sampler
att2 = torch.sum(gs.ops.u_mul_v(subA, u_feats @ W_2,
File "/home/ubuntu/anaconda3/envs/dgl/lib/python3.9/site-packages/gs-0.1-py3.9.egg/gs/ops/sddmm.py", line 115, in func
return gsddmm(g, binary_op, x, y,
File "/home/ubuntu/anaconda3/envs/dgl/lib/python3.9/site-packages/gs-0.1-py3.9.egg/gs/ops/sddmm.py", line 72, in gsddmm
return gsddmm_internal(
File "/home/ubuntu/anaconda3/envs/dgl/lib/python3.9/site-packages/gs-0.1-py3.9.egg/gs/ops/sparse.py", line 286, in gsddmm
return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target, on_format)
(Triggered internally at /opt/conda/conda-bld/pytorch_1656352657443/work/torch/csrc/autograd/python_anomaly_mode.cpp:102.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
0%| | 0/300 [00:01<?, ?it/s]
Traceback (most recent call last):
File "/home/ubuntu/aws_projects/graph_sampling/examples/pass/train_minibatch.py", line 247, in <module>
train(dataset, args)
File "/home/ubuntu/aws_projects/graph_sampling/examples/pass/train_minibatch.py", line 157, in train
sample_loss.backward()
File "/home/ubuntu/anaconda3/envs/dgl/lib/python3.9/site-packages/torch/_tensor.py", line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/ubuntu/anaconda3/envs/dgl/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'GSDDMMBackward' returned nan values in its 0th output.
[BUG] Each iteration produces the same sampling result
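The _SampleSubIndicesKernelFusedWithReplace kernel initializes its cuRAND state from a fixed seed (random_seed = 7777777), so every call draws the same random sequence and the sampling results are almost identical across iterations. This can lead to poor training accuracy. Note also that in the code as shown, the curand-based random path is commented out and sequential sampling (edge = tid % degree) is used instead, which returns the same neighbours on every call regardless of the seed.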
template <typename IdType>
__global__ void _SampleSubIndicesKernelFusedWithReplace(IdType* sub_indices,
                                                        IdType* indptr, IdType* indices,
                                                        IdType* sub_indptr,
                                                        IdType* column_ids, int64_t size) {
  int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
  const uint64_t random_seed = 7777777;  // BUG: fixed seed, every launch replays the same random stream
  curandState rng;
  curand_init(random_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
  while (row < size) {
    int64_t col = column_ids[row];
    int64_t in_start = indptr[col];
    int64_t out_start = sub_indptr[row];
    int64_t degree = indptr[col + 1] - indptr[col];
    int64_t fanout = sub_indptr[row + 1] - sub_indptr[row];
    int64_t tid = threadIdx.x;
    while (tid < fanout) {
      // Sequential Sampling
      const int64_t edge = tid % degree;
      // Random Sampling
      // const int64_t edge = curand(&rng) % degree;
      sub_indices[out_start + tid] = indices[in_start + edge];
      tid += blockDim.x;
    }
    row += gridDim.x * blockDim.y;
  }
}
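One common remedy is to pass the seed in from the host and change it on every call, for example by giving the kernel a hypothetical uint64_t seed parameter that replaces the constant in curand_init. The CPU sketch below illustrates the effect; it is an analogy, not the gs kernel: std::mt19937_64 and std::uniform_int_distribution stand in for curand, each row gets its own engine seeded from (seed, row), and SampleWithReplace and SeedSequence are hypothetical names. With the seed fixed, two calls return identical samples; with a new seed per call, they differ.

```cpp
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// CPU analogy of sampling `fanout` neighbours with replacement per column.
// indptr/indices: CSC graph; column_ids: columns to sample; seed: per-call seed.
std::vector<int64_t> SampleWithReplace(const std::vector<int64_t>& indptr,
                                       const std::vector<int64_t>& indices,
                                       const std::vector<int64_t>& column_ids,
                                       int64_t fanout, uint64_t seed) {
  std::vector<int64_t> out;
  out.reserve(column_ids.size() * static_cast<std::size_t>(fanout));
  for (std::size_t row = 0; row < column_ids.size(); ++row) {
    int64_t col = column_ids[row];
    int64_t in_start = indptr[col];
    int64_t degree = indptr[col + 1] - indptr[col];
    if (degree == 0) continue;  // skip columns with no neighbours
    // One engine per row (stand-in for curand_init); a new seed gives a new stream.
    std::mt19937_64 rng(seed * 1000003ULL + row);
    std::uniform_int_distribution<int64_t> pick(0, degree - 1);
    for (int64_t i = 0; i < fanout; ++i) {
      out.push_back(indices[in_start + pick(rng)]);
    }
  }
  return out;
}

// Host-side seed source: yields a fresh seed for every sampling call.
class SeedSequence {
 public:
  explicit SeedSequence(uint64_t base) : gen_(base) {}
  uint64_t Next() { return gen_(); }

 private:
  std::mt19937_64 gen_;
};
```

Where the per-call seed comes from (a host counter, a seeded engine as above, or the framework's own generator) is a design choice the issue does not specify; the essential change is that the value reaching curand_init differs between iterations.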