
Commit

Add CudnnConvAlgoCache (#3649)
* Add CudnnConvAlgoCache

* refine

Former-commit-id: a2af59e
liujuncheng authored Oct 5, 2020
1 parent 763ad3c commit b463722
Showing 3 changed files with 84 additions and 2 deletions.
55 changes: 53 additions & 2 deletions oneflow/core/device/cudnn_conv_util.cpp
@@ -130,8 +130,60 @@ perf_t GetBestAlgorithm(const CudnnConvArgs& args, CudnnConvResource* res,
  return perf_vec.at(found_algo_idx);
}

template<typename perf_t>
perf_t CudnnConvAlgoGetOrInfer(const CudnnConvParams& params,
                               const std::function<perf_t(const CudnnConvParams&)>& InferFn,
                               CudnnConvAlgoCache::Store<perf_t>* store, std::mutex* mutex) {
  const size_t cache_size = Global<ResourceDesc, ForSession>::Get()->thread_local_cache_max_size();
  auto InferWithCache = [&](const CudnnConvParams& p) -> perf_t {
    CudnnConvParams params_without_ws = p;
    params_without_ws.max_ws_size = 0;
    std::unique_lock<std::mutex> lock(*mutex);
    const auto& key_it = store->find(params_without_ws);
    if (key_it != store->cend()) {
      const auto& perf_it =
          std::find_if(key_it->second.cbegin(), key_it->second.cend(),
                       [&](const std::pair<size_t, perf_t>& pair) {
                         // A cached result is reusable when the winning algorithm's memory
                         // requirement fits the current limit, even though it was found under a
                         // larger workspace limit.
                         return pair.second.memory <= p.max_ws_size /* for memory safety */
                                && pair.first >= p.max_ws_size /* inferred with an equal or larger workspace before */;
                       });
      if (perf_it != key_it->second.cend()) { return perf_it->second; }
    }
    perf_t perf = InferFn(p);
    (*store)[params_without_ws].push_back(std::make_pair(p.max_ws_size, perf));
    return perf;
  };
  return ThreadLocalCachedCall(cache_size, InferWithCache, params);
}

} // namespace

template<>
cudnnConvolutionFwdAlgoPerf_t CudnnConvAlgoCache::Remember(
    const CudnnConvParams& params,
    const std::function<cudnnConvolutionFwdAlgoPerf_t(const CudnnConvParams&)>& InferFn) {
  return CudnnConvAlgoGetOrInfer<cudnnConvolutionFwdAlgoPerf_t>(params, InferFn, &fwd_algo_store_,
                                                                &fwd_algo_store_mutex_);
}

template<>
cudnnConvolutionBwdDataAlgoPerf_t CudnnConvAlgoCache::Remember(
    const CudnnConvParams& params,
    const std::function<cudnnConvolutionBwdDataAlgoPerf_t(const CudnnConvParams&)>& InferFn) {
  return CudnnConvAlgoGetOrInfer<cudnnConvolutionBwdDataAlgoPerf_t>(
      params, InferFn, &bwd_data_algo_store_, &bwd_data_algo_store_mutex_);
}

template<>
cudnnConvolutionBwdFilterAlgoPerf_t CudnnConvAlgoCache::Remember(
    const CudnnConvParams& params,
    const std::function<cudnnConvolutionBwdFilterAlgoPerf_t(const CudnnConvParams&)>& InferFn) {
  return CudnnConvAlgoGetOrInfer<cudnnConvolutionBwdFilterAlgoPerf_t>(
      params, InferFn, &bwd_filter_algo_store_, &bwd_filter_algo_cache_mutex_);
}

CudnnConvDesc::~CudnnConvDesc() { OF_CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(val_)); }

CudnnConvDesc::CudnnConvDesc(const DataType compute_type, const DataType data_type,
@@ -405,8 +457,7 @@ perf_t FindCudnnConvAlgorithmWithResource(CudnnConvArgs* args, CudnnConvResource
    }
    return GetBestAlgorithm<perf_t>(*args, res, perf_vec);
  };
-  size_t cache_size = Global<ResourceDesc, ForSession>::Get()->thread_local_cache_max_size();
-  return ThreadLocalCachedCall(cache_size, Infer, args->params);
+  return Global<CudnnConvAlgoCache>::Get()->Remember<perf_t>(args->params, Infer);
}

template<typename perf_t, typename algo_t>
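The heart of this hunk is the reuse test inside CudnnConvAlgoGetOrInfer: a perf result found under one workspace limit can answer a later request with a smaller limit, because the larger search already covered every algorithm that is feasible under the smaller one. The standalone sketch below restates that rule with stand-in types; PerfStub and CanReuse are illustrative names, not part of this commit.

#include <cassert>
#include <cstddef>

// Stand-in for cudnn*AlgoPerf_t: only the field the reuse test reads.
struct PerfStub {
  size_t memory;  // workspace bytes the winning algorithm actually needs
};

// Mirrors the find_if predicate above: a result searched under limit
// ws_searched answers a request with limit ws_requested when
//  (1) the winning algorithm fits the new limit, and
//  (2) the old limit was at least as large, so the old search already
//      considered every algorithm that is feasible now.
bool CanReuse(size_t ws_searched, const PerfStub& best, size_t ws_requested) {
  return best.memory <= ws_requested && ws_searched >= ws_requested;
}

int main() {
  PerfStub best{64 << 20};                       // winner needed 64 MiB
  assert(CanReuse(256 << 20, best, 128 << 20));  // smaller request: reuse
  assert(!CanReuse(256 << 20, best, 32 << 20));  // winner would not fit: miss
  assert(!CanReuse(64 << 20, best, 128 << 20));  // earlier search was narrower: miss
  return 0;
}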
28 changes: 28 additions & 0 deletions oneflow/core/device/cudnn_conv_util.h
@@ -198,6 +198,34 @@ struct hash<oneflow::CudnnConvParams> final {

} // namespace std

namespace oneflow {

class CudnnConvAlgoCache final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudnnConvAlgoCache);
  CudnnConvAlgoCache() = default;
  ~CudnnConvAlgoCache() = default;

  template<typename perf_t>
  using WorkspaceSizeAndPerfT = std::pair<size_t, perf_t>;
  template<typename perf_t>
  using Store = HashMap<CudnnConvParams, std::list<WorkspaceSizeAndPerfT<perf_t>>>;

  template<typename perf_t>
  perf_t Remember(const CudnnConvParams& params,
                  const std::function<perf_t(const CudnnConvParams& param)>& InferFn);

 private:
  Store<cudnnConvolutionFwdAlgoPerf_t> fwd_algo_store_;
  std::mutex fwd_algo_store_mutex_;
  Store<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algo_store_;
  std::mutex bwd_data_algo_store_mutex_;
  Store<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algo_store_;
  std::mutex bwd_filter_algo_cache_mutex_;
};

} // namespace oneflow

#endif // WITH_CUDA

#endif // ONEFLOW_CORE_DEVICE_CUDNN_CONV_UTIL_H_
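Remember funnels every query through ThreadLocalCachedCall (see the .cpp hunk above), so a thread answers repeated lookups from a small local table and only takes the store mutex on a local miss. Below is a minimal sketch of that two-level arrangement, assuming ThreadLocalCachedCall behaves like a bounded per-thread memo table; ThreadLocalCachedCallStub and its signature are illustrative, not OneFlow's API.

#include <cstddef>
#include <functional>
#include <iostream>
#include <unordered_map>

// Hypothetical stand-in for ThreadLocalCachedCall: memoize up to max_size
// results per thread before falling through to the shared, mutex-guarded
// store that fn represents here.
template<typename K, typename V>
V ThreadLocalCachedCallStub(size_t max_size, const std::function<V(const K&)>& fn, const K& key) {
  thread_local std::unordered_map<K, V> cache;
  auto it = cache.find(key);
  if (it != cache.end()) { return it->second; }
  V val = fn(key);  // local miss: consult the shared store
  if (cache.size() < max_size) { cache.emplace(key, val); }
  return val;
}

int main() {
  int shared_lookups = 0;
  std::function<int(const int&)> fn = [&](const int& k) {
    ++shared_lookups;  // stands in for the locked lookup plus cuDNN search
    return k * k;
  };
  ThreadLocalCachedCallStub<int, int>(128, fn, 7);
  ThreadLocalCachedCallStub<int, int>(128, fn, 7);  // served thread-locally
  std::cout << "shared lookups: " << shared_lookups << '\n';  // prints 1
  return 0;
}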
3 changes: 3 additions & 0 deletions oneflow/core/job/env_global_objects_scope.cpp
@@ -31,6 +31,7 @@ limitations under the License.
#include "oneflow/core/vm/virtual_machine_scope.h"
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
#include "oneflow/core/job/eager_nccl_comm_manager.h"
#include "oneflow/core/device/cudnn_conv_util.h"

namespace oneflow {

@@ -91,12 +92,14 @@ Maybe<void> EnvGlobalObjectsScope::Init(const EnvProto& env_proto) {
  Global<EagerJobBuildAndInferCtxMgr>::New();
#ifdef WITH_CUDA
  Global<EagerNcclCommMgr>::New();
  Global<CudnnConvAlgoCache>::New();
#endif
  return Maybe<void>::Ok();
}

EnvGlobalObjectsScope::~EnvGlobalObjectsScope() {
#ifdef WITH_CUDA
  Global<CudnnConvAlgoCache>::Delete();
  Global<EagerNcclCommMgr>::Delete();
#endif
  Global<EagerJobBuildAndInferCtxMgr>::Delete();
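This last hunk registers the cache with the process-wide registry: New in EnvGlobalObjectsScope::Init, Delete in the destructor, with teardown mirroring construction in reverse (CudnnConvAlgoCache is created after EagerNcclCommMgr and destroyed before it). Here is a stripped-down sketch of the Global<T> pattern these calls assume; the real oneflow::Global is considerably more elaborate.

#include <cassert>

// Simplified Global<T>: one explicitly managed instance per type,
// constructed and destroyed by the owning scope.
template<typename T>
class Global {
 public:
  static T* Get() { return ptr_; }
  static void New() { ptr_ = new T(); }
  static void Delete() {
    delete ptr_;
    ptr_ = nullptr;
  }

 private:
  static T* ptr_;
};
template<typename T>
T* Global<T>::ptr_ = nullptr;

struct DummyCache {};  // stand-in for CudnnConvAlgoCache

int main() {
  Global<DummyCache>::New();  // Init: construct in declaration order
  assert(Global<DummyCache>::Get() != nullptr);
  Global<DummyCache>::Delete();  // teardown: destroy in reverse order
  assert(Global<DummyCache>::Get() == nullptr);
  return 0;
}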
