Skip to content

Commit

Permalink
[backend] implement pthread backend (#46)
Browse files Browse the repository at this point in the history
* [backend] init pthread pool

* [pthread] make it compilable

* [pthread] pass callback f ptr

* [pthread] sync api with other backends

* [pthread] fetch environ

* [pthread] implement sync

* [pthread] fix

* [pthread] debug hang

* [pthread] use c++ 3rd lib thread pool

* [pthread] pass test

* [chore] remove unintended commit

* [chore] refine

* [chore] license
  • Loading branch information
botbw authored Sep 24, 2024
1 parent f165590 commit 51ed242
Show file tree
Hide file tree
Showing 8 changed files with 1,351 additions and 6 deletions.
27 changes: 23 additions & 4 deletions csrc/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#ifndef DISABLE_AIO
#include "aio.h"
#endif
#ifndef DISABLE_PTHREAD
#include "pthread_backend.h"
#endif

std::unordered_set<std::string> get_backends()
{
Expand All @@ -20,6 +23,9 @@ std::unordered_set<std::string> get_backends()
#endif
#ifndef DISABLE_AIO
backends.insert("aio");
#endif
#ifndef DISABLE_PTHREAD
backends.insert("pthread");
#endif
return backends;
}
Expand All @@ -35,18 +41,27 @@ void probe_asyncio(const std::string &backend)
try
{
std::unique_ptr<AsyncIO> aio;
if (backend == "uring")
if (backend == "uring") {
#ifndef DISABLE_URING
aio.reset(new UringAsyncIO(2));
#else
throw std::runtime_error("backend is not installed\n");
throw std::runtime_error("backend uring is not installed\n");
#endif
else
} else if (backend == "aio") {
#ifndef DISABLE_AIO
aio.reset(new AIOAsyncIO(2));
#else
throw std::runtime_error("backend is not installed\n");
throw std::runtime_error("backend aio is not installed\n");
#endif
} else if (backend == "pthread") {
#ifndef DISABLE_PTHREAD
aio.reset(new PthreadAsyncIO(2));
#else
throw std::runtime_error("backend pthread is not installed\n");
#endif
} else {
throw std::runtime_error("unknown backend");
}

int fd = fileno(fp);
const int n_loop = 5, n_len = 18;
Expand Down Expand Up @@ -120,6 +135,10 @@ AsyncIO *create_asyncio(unsigned int n_entries, const std::string &backend)
#ifndef DISABLE_AIO
if (backend == "aio")
return new AIOAsyncIO(n_entries);
#endif
#ifndef DISABLE_PTHREAD
if (backend == "pthread")
return new PthreadAsyncIO(n_entries);
#endif
throw std::runtime_error("Unsupported backend: " + backend);
}
79 changes: 79 additions & 0 deletions csrc/pthread_backend.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include "pthread_backend.h"

void PthreadAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) {
auto fut = this->pool.submit_task(
[fd, buffer, n_bytes, offset] {
return pwrite(fd, buffer, n_bytes, offset);
}
);
this->write_fut.push_back(std::make_tuple(std::move(fut), callback));
}

void PthreadAsyncIO::writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback) {
auto fut = this->pool.submit_task(
[fd, iov, iovcnt, offset] {
return pwritev(fd, iov, iovcnt, offset);
}
);
this->write_fut.push_back(std::make_tuple(std::move(fut), callback));
}

void PthreadAsyncIO::read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) {
auto fut = this->pool.submit_task(
[fd, buffer, n_bytes, offset] {
return pread(fd, buffer, n_bytes, offset);
}
);
this->read_fut.push_back(std::make_tuple(std::move(fut), callback));
}

void PthreadAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback) {
auto fut = this->pool.submit_task(
[fd, iov, iovcnt, offset] {
return preadv(fd, iov, iovcnt, offset);
}
);
this->read_fut.push_back(std::make_tuple(std::move(fut), callback));
}

void PthreadAsyncIO::get_event(WaitType wt) {
if (wt == NOWAIT) return;
this->sync_write_events();
this->sync_read_events();
}

void PthreadAsyncIO::sync_write_events() {
while (this->write_fut.size() > 0) {
auto front = std::move(this->write_fut.front());
this->write_fut.pop_front();

auto fut(std::move(std::get<0>(front)));
fut.wait();

auto callback = std::get<1>(front);
if (callback != nullptr) {
callback();
}
}
}

void PthreadAsyncIO::sync_read_events() {
while (this->read_fut.size() > 0) {
auto front = std::move(this->read_fut.front());
this->read_fut.pop_front();

auto fut(std::move(std::get<0>(front)));
fut.wait();

auto callback = std::get<1>(front);
if (callback != nullptr) {
callback();
}
}
}

void PthreadAsyncIO::synchronize() {
this->get_event(WAIT);
}

void PthreadAsyncIO::register_file(int fd) {}
2 changes: 2 additions & 0 deletions csrc/py_api.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include "offload.h"
#include "async_file_io.h"
#include "backend.h"
Expand Down
2 changes: 2 additions & 0 deletions include/offload.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include "asyncio.h"
#include <ATen/ATen.h>

#include "space_mgr.h"
#ifndef DISABLE_URING
#include "uring.h"
Expand Down
41 changes: 41 additions & 0 deletions include/pthread_backend.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#pragma once

#include <stdexcept>
#include <sys/io.h>
#include <sys/uio.h>
#include <unistd.h>
#include <cstdlib>
#include <future>
#include <queue>
#include <tuple>
#include <functional>

#include "asyncio.h"
#include "threadpool.hpp"


class PthreadAsyncIO : public AsyncIO
{
private:
BS::thread_pool pool;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> read_fut;

public:
PthreadAsyncIO(unsigned int n_entries)
: pool(n_entries) {}

~PthreadAsyncIO() {}

void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
void read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);

void get_event(WaitType wt);
void sync_write_events();
void sync_read_events();
void synchronize();

void register_file(int fd);
};
Loading

0 comments on commit 51ed242

Please sign in to comment.