Skip to content

Commit

Permalink
impl sctp for RTCHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
yngrtc committed Jun 29, 2024
1 parent c458888 commit 3190812
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 75 deletions.
2 changes: 0 additions & 2 deletions rtc/src/handler/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
pub mod demuxer;
pub mod dtls;
pub mod ice;
/*TODO:
pub mod sctp;
*/
160 changes: 88 additions & 72 deletions rtc/src/handler/sctp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,22 @@ use shared::error::{Error, Result};
use shared::handler::RTCHandler;
use shared::Transmit;
use std::collections::{HashMap, VecDeque};
use std::time::Instant;
use std::time::{Duration, Instant};

enum SctpMessage {
Inbound(DataChannelMessage),
Outbound(Transmit<sctp::Payload>),
}

impl RTCHandler for RTCSctpTransport {
fn handle_transmit(&mut self, msg: Transmit<RTCMessage>) -> Vec<Transmit<RTCMessage>> {
type Ein = ();
type Eout = RTCEvent;
type Rin = RTCMessage;
type Rout = RTCMessage;
type Win = RTCMessage;
type Wout = RTCMessage;

fn handle_read(&mut self, msg: Transmit<Self::Rin>) -> Result<()> {
if let RTCMessage::Dtls(DTLSMessage::Raw(dtls_message)) = msg.message {
debug!("recv sctp RAW {:?}", msg.transport.peer_addr);

Expand Down Expand Up @@ -101,7 +108,6 @@ impl RTCHandler for RTCSctpTransport {
Ok(messages)
};

let mut next_msgs = vec![];
match try_read() {
Ok(messages) => {
for message in messages {
Expand All @@ -111,7 +117,7 @@ impl RTCHandler for RTCSctpTransport {
"recv sctp data channel message {:?}",
msg.transport.peer_addr
);
next_msgs.push(Transmit {
self.routs.push_back(Transmit {
now: msg.now,
transport: msg.transport,
message: RTCMessage::Dtls(DTLSMessage::Sctp(message)),
Expand All @@ -120,7 +126,7 @@ impl RTCHandler for RTCSctpTransport {
SctpMessage::Outbound(transmit) => {
if let Payload::RawEncode(raw_data) = transmit.message {
for raw in raw_data {
self.transmits.push_back(Transmit {
self.wouts.push_back(Transmit {
now: transmit.now,
transport: transmit.transport,
message: RTCMessage::Dtls(DTLSMessage::Raw(
Expand All @@ -135,99 +141,106 @@ impl RTCHandler for RTCSctpTransport {
}
Err(err) => {
error!("try_read with error {}", err);
self.handle_error(err);
return Err(err);
}
}
next_msgs
} else {
// Bypass
debug!("bypass sctp read {:?}", msg.transport.peer_addr);
vec![msg]
self.routs.push_back(msg)
}

Ok(())
}

fn poll_transmit(&mut self, msg: Option<Transmit<RTCMessage>>) -> Option<Transmit<RTCMessage>> {
if let Some(msg) = msg {
if let RTCMessage::Dtls(DTLSMessage::Sctp(message)) = msg.message {
debug!(
"send sctp data channel message {:?}",
msg.transport.peer_addr
);
fn poll_read(&mut self) -> Option<Transmit<Self::Rout>> {
self.routs.pop_front()
}

let mut try_write = || -> Result<Vec<Transmit<Payload>>> {
let mut transmits = vec![];
fn handle_write(&mut self, msg: Transmit<Self::Win>) -> Result<()> {
if let RTCMessage::Dtls(DTLSMessage::Sctp(message)) = msg.message {
debug!(
"send sctp data channel message {:?}",
msg.transport.peer_addr
);

let max_message_size = self.max_message_size;
if message.payload.len() > max_message_size {
return Err(Error::ErrOutboundPacketTooLarge);
}
let mut try_write = || -> Result<Vec<Transmit<Payload>>> {
let mut transmits = vec![];

if let Some(conn) = self
.sctp_associations
.get_mut(&AssociationHandle(message.association_handle))
let max_message_size = self.max_message_size;
if message.payload.len() > max_message_size {
return Err(Error::ErrOutboundPacketTooLarge);
}

if let Some(conn) = self
.sctp_associations
.get_mut(&AssociationHandle(message.association_handle))
{
let mut stream = conn.stream(message.stream_id)?;
if let Some(DataChannelMessageParams {
unordered,
reliability_type,
reliability_parameter,
}) = message.params
{
let mut stream = conn.stream(message.stream_id)?;
if let Some(DataChannelMessageParams {
stream.set_reliability_params(
unordered,
reliability_type,
reliability_parameter,
}) = message.params
{
stream.set_reliability_params(
unordered,
reliability_type,
reliability_parameter,
)?;
}
stream.write_with_ppi(
&message.payload,
to_ppid(message.data_message_type, message.payload.len()),
)?;
}
stream.write_with_ppi(
&message.payload,
to_ppid(message.data_message_type, message.payload.len()),
)?;

while let Some(x) = conn.poll_transmit(msg.now) {
transmits.extend(split_transmit(x));
}
} else {
return Err(Error::ErrAssociationNotExisted);
while let Some(x) = conn.poll_transmit(msg.now) {
transmits.extend(split_transmit(x));
}
Ok(transmits)
};
match try_write() {
Ok(transmits) => {
for transmit in transmits {
if let Payload::RawEncode(raw_data) = transmit.message {
for raw in raw_data {
self.transmits.push_back(Transmit {
now: transmit.now,
transport: transmit.transport,
message: RTCMessage::Dtls(DTLSMessage::Raw(
BytesMut::from(&raw[..]),
)),
});
}
} else {
return Err(Error::ErrAssociationNotExisted);
}
Ok(transmits)
};
match try_write() {
Ok(transmits) => {
for transmit in transmits {
if let Payload::RawEncode(raw_data) = transmit.message {
for raw in raw_data {
self.wouts.push_back(Transmit {
now: transmit.now,
transport: transmit.transport,
message: RTCMessage::Dtls(DTLSMessage::Raw(BytesMut::from(
&raw[..],
))),
});
}
}
}
Err(err) => {
error!("try_write with error {}", err);
self.handle_error(err);
}
Ok(())
}
Err(err) => {
error!("try_write with error {}", err);
Err(err)
}
} else {
// Bypass
debug!("Bypass sctp write {:?}", msg.transport.peer_addr);
self.transmits.push_back(msg);
}
} else {
// Bypass
debug!("Bypass sctp write {:?}", msg.transport.peer_addr);
self.wouts.push_back(msg);
Ok(())
}
}

self.transmits.pop_front()
fn poll_write(&mut self) -> Option<Transmit<RTCMessage>> {
self.wouts.pop_front()
}

fn poll_event(&mut self) -> Option<RTCEvent> {
self.events.pop_front().map(RTCEvent::SctpTransportEvent)
}

fn handle_timeout(&mut self, now: Instant) {
fn handle_timeout(&mut self, now: Instant) -> Result<()> {
let mut try_timeout = || -> Result<Vec<Transmit<Payload>>> {
let mut transmits = vec![];

Expand Down Expand Up @@ -263,7 +276,7 @@ impl RTCHandler for RTCSctpTransport {
for transmit in transmits {
if let Payload::RawEncode(raw_data) = transmit.message {
for raw in raw_data {
self.transmits.push_back(Transmit {
self.wouts.push_back(Transmit {
now: transmit.now,
transport: transmit.transport,
message: RTCMessage::Dtls(DTLSMessage::Raw(BytesMut::from(
Expand All @@ -273,22 +286,25 @@ impl RTCHandler for RTCSctpTransport {
}
}
}
Ok(())
}
Err(err) => {
error!("try_timeout with error {}", err);
self.handle_error(err);
Err(err)
}
}
}

fn poll_timeout(&mut self, eto: &mut Instant) {
fn poll_timeout(&mut self) -> Option<Instant> {
let mut eto = Instant::now() + Duration::from_secs(86400); // 1 day
for conn in self.sctp_associations.values() {
if let Some(timeout) = conn.poll_timeout() {
if timeout < *eto {
*eto = timeout;
if timeout < eto {
eto = timeout;
}
}
}
Some(eto)
}
}

Expand Down
3 changes: 2 additions & 1 deletion rtc/src/transport/sctp_transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ pub struct RTCSctpTransport {

pub(crate) internal_buffer: Vec<u8>,
pub(crate) events: VecDeque<SctpTransportEvent>,
pub(crate) transmits: VecDeque<Transmit<RTCMessage>>,
pub(crate) routs: VecDeque<Transmit<RTCMessage>>,
pub(crate) wouts: VecDeque<Transmit<RTCMessage>>,
}

impl RTCSctpTransport {
Expand Down

0 comments on commit 3190812

Please sign in to comment.