Skip to content

Commit

Permalink
Get device index from local rank if multi-client, otherwise use the c…
Browse files Browse the repository at this point in the history
…urrent device. (#6405)

* Fix random generator

* Get device index from local rank if multi-client, otherwise use current device.

Co-authored-by: oneflow-ci-bot <[email protected]>
  • Loading branch information
hjchen2 and oneflow-ci-bot authored Sep 26, 2021
1 parent 6273773 commit 89bbc5b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
4 changes: 2 additions & 2 deletions oneflow/core/framework/random_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Maybe<Generator> DefaultCUDAGenerator(int device_index) {
static std::vector<std::once_flag> init_flags(device_count);
static std::vector<std::shared_ptr<Generator>> default_cuda_generator(device_count);

if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }
if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); }
CHECK_OR_RETURN(device_index >= 0 && device_index < device_count)
<< "Invalid device index " << device_index;
std::call_once(init_flags[device_index], [&]() {
Expand All @@ -91,7 +91,7 @@ Maybe<Generator> MakeCPUGenerator() {

#ifdef WITH_CUDA
Maybe<Generator> MakeCUDAGenerator(int device_index) {
if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }
if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); }
CHECK_OR_RETURN(device_index >= 0 && device_index < detail::GetCudaDeviceCount())
<< "Invalid device index " << device_index;
return std::make_shared<Generator>(
Expand Down
35 changes: 25 additions & 10 deletions oneflow/core/framework/random_generator_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,9 @@ int GetThreadNum(const cudaDeviceProp& prop) {
}
}

Maybe<void> CUDASynchronize(int device_index) {
Maybe<void> CUDASynchronize() {
// Synchronize cuda device to avoid state been modified in random kernels.
JUST(CPUSynchronize());
OF_CUDA_CHECK(cudaSetDevice(device_index));
OF_CUDA_CHECK(cudaDeviceSynchronize());
return Maybe<void>::Ok();
}
Expand All @@ -161,25 +160,29 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(uint64_t seed, int device_index)
OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_index));
max_block_num_ = prop.multiProcessorCount;
max_thread_num_ = GetThreadNum(prop);
OF_CUDA_CHECK(cudaSetDevice(device_index));

CudaCurrentDeviceGuard dev_guard(device_index);
OF_CUDA_CHECK(
cudaMalloc(&curand_states_, max_block_num_ * max_thread_num_ * sizeof(curandState)));
detail::InitCurandStates(seed, max_block_num_, max_thread_num_, curand_states_);
}

CUDAGeneratorImpl::~CUDAGeneratorImpl() {
CHECK_JUST(CUDASynchronize(this->device_index()));
CudaCurrentDeviceGuard dev_guard(this->device_index());
CHECK_JUST(CUDASynchronize());
OF_CUDA_CHECK(cudaFree(curand_states_));
}

void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
CHECK_JUST(CUDASynchronize(this->device_index()));
CudaCurrentDeviceGuard dev_guard(this->device_index());
CHECK_JUST(CUDASynchronize());
seed_ = seed;
detail::InitCurandStates(seed_, max_block_num_, max_thread_num_, curand_states_);
}

Maybe<Tensor> CUDAGeneratorImpl::GetState() const {
JUST(CUDASynchronize(this->device_index()));
CudaCurrentDeviceGuard dev_guard(this->device_index());
JUST(CUDASynchronize());
int64_t state_size = max_block_num_ * max_thread_num_ * sizeof(curandState);
int64_t total_size = state_size + sizeof(int64_t);
const auto& device = JUST(Device::New("cpu"));
Expand Down Expand Up @@ -207,7 +210,8 @@ Maybe<void> CUDAGeneratorImpl::SetState(const std::shared_ptr<Tensor>& tensor_st
<< total_size << ", but got " << tensor_state->shape()->elem_cnt();
}

JUST(CUDASynchronize(this->device_index()));
CudaCurrentDeviceGuard dev_guard(this->device_index());
JUST(CUDASynchronize());
const auto& callback = std::make_shared<std::function<void(uint64_t)>>([&](uint64_t of_blob_ptr) {
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
const int8_t* data = of_blob->blob().dptr<int8_t>();
Expand Down Expand Up @@ -398,16 +402,27 @@ Maybe<CPUGeneratorImpl> MakeGeneratorImpl<CPUGeneratorImpl>(uint64_t seed, int d
}

#ifdef WITH_CUDA

int GetCudaDeviceIndex() {
int cuda_device_index = 0;
if (CHECK_JUST(GlobalMultiClientEnv())) {
cuda_device_index = GlobalProcessCtx::LocalRank();
} else {
OF_CUDA_CHECK(cudaGetDevice(&cuda_device_index));
}
return cuda_device_index;
}

int GetCudaDeviceCount() {
/* static */ int cuda_device_count;
OF_CUDA_CHECK(cudaSetDevice(GlobalProcessCtx::LocalRank()));
/* static */ int cuda_device_count = 0;
CudaCurrentDeviceGuard dev_guard(detail::GetCudaDeviceIndex());
OF_CUDA_CHECK(cudaGetDeviceCount(&cuda_device_count));
return cuda_device_count;
}

template<>
DeviceKey MakeDeviceKey<CUDAGeneratorImpl>(int device_index) {
if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }
if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); }
DeviceKey device_key;
device_key.device_type = DeviceType::kGPU;
device_key.device_index = device_index;
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/framework/random_generator_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class CUDAGeneratorImpl : public DeviceGeneratorImpl {

namespace detail {

int GetCudaDeviceIndex();
int GetCudaDeviceCount();

void InitCurandStates(uint64_t seed, int32_t block_num, int32_t thread_num, curandState* states);
Expand Down

0 comments on commit 89bbc5b

Please sign in to comment.