Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions network/src/p2p/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ pub struct Handler {

listener: Listener,

establish_lock: Mutex<()>,
tokens: Mutex<TokenGenerator>,

routing_table: Arc<RoutingTable>,
Expand Down Expand Up @@ -126,6 +127,7 @@ impl Handler {
socket_address,
listener: Listener::bind(&socket_address).expect("Cannot listen TCP port"),

establish_lock: Mutex::new(()),
tokens: Mutex::new(TokenGenerator::new(FIRST_CONNECTION_TOKEN, LAST_CONNECTION_TOKEN)),

routing_table,
Expand Down Expand Up @@ -168,7 +170,7 @@ impl Handler {
}
}

fn connect(&self, socket_address: &SocketAddr) -> IoHandlerResult<Option<StreamToken>> {
fn connect(&self, io: &IoContext<Message>, socket_address: &SocketAddr) -> IoHandlerResult<Option<StreamToken>> {
let ip = socket_address.ip();
if !self.filters.is_allowed(&ip) {
cinfo!(NETWORK, "P2P connection from {} is received. But it's not allowed", ip);
Expand All @@ -179,6 +181,7 @@ impl Handler {
Some(stream) => {
let remote_node_id = socket_address.into();

let _establish_lock = self.establish_lock.lock();
let local_node_id =
self.routing_table.local_node_id(&remote_node_id).ok_or(Error::General("Not handshaked"))?;
let session = self
Expand All @@ -189,7 +192,9 @@ impl Handler {
let mut tokens = self.tokens.lock();
let token = tokens.gen().ok_or(Error::General("TooManyConnections"))?;
if self.connections.connect(token, stream, local_node_id, session, socket_address, self.get_port()) {
self.routing_table.establish(socket_address);
const CONNECTION_TIMEOUT_MS: u64 = 3_000;
io.register_timer_once(token as TimerToken, CONNECTION_TIMEOUT_MS)?;
self.routing_table.set_establishing(socket_address);
Some(token)
} else {
cwarn!(NETWORK, "Cannot create connection to {}", socket_address);
Expand Down Expand Up @@ -238,10 +243,14 @@ impl Handler {
Some(ReceivedMessage::Ack {
..
}) => {
let _establish_lock = self.establish_lock.lock();
if !self.connections.establish_wait_ack_connection(stream) {
return Err(Error::InvalidStream(*stream).into())
}

let node_id = self.connections.node_id(&stream).ok_or(Error::InvalidStream(*stream))?;
self.routing_table.establish(&node_id.into_addr());
io.clear_timer(*stream as TimerToken)?;
io.message(Message::RequestNegotiation {
node_id,
})?;
Expand Down Expand Up @@ -271,6 +280,8 @@ impl Handler {
}

let remote_addr = SocketAddr::new(remote_addr.ip(), port);

let _establish_lock = self.establish_lock.lock();
let session = self
.routing_table
.unestablished_session(&remote_addr)
Expand Down Expand Up @@ -387,6 +398,16 @@ impl IoHandler<Message> for Handler {
}
Ok(())
}
FIRST_CONNECTION_TOKEN...LAST_CONNECTION_TOKEN => {
let node_id = self.connections.node_id(&token).ok_or(Error::InvalidStream(token))?;
let address = node_id.into_addr();

if !self.routing_table.reset_session(&address) {
return Err(Error::General("Failed to find session").into())
}
self.connections.shutdown(&address)?;
Ok(())
}
_ => unreachable!(),
}
}
Expand All @@ -403,7 +424,7 @@ impl IoHandler<Message> for Handler {
}

ctrace!(NETWORK, "Connecting to {}", socket_address);
let token = self.connect(&socket_address)?.ok_or(Error::General("Cannot create connection"))?;
let token = self.connect(io, &socket_address)?.ok_or(Error::General("Cannot create connection"))?;
cinfo!(NETWORK, "New connection to {}({})", socket_address, token);
io.register_stream(token)?;
Ok(())
Expand Down Expand Up @@ -474,7 +495,7 @@ impl IoHandler<Message> for Handler {
let was_established = self.connections.is_established(&stream);
self.connections.set_disconnecting(&stream);
let node_id = self.connections.node_id(&stream).ok_or(Error::InvalidStream(stream))?;
self.routing_table.remove_node(node_id.into_addr());
self.routing_table.remove_node_on_shutdown(node_id.into_addr());
if was_established {
self.client.on_node_removed(&node_id);
}
Expand Down
62 changes: 57 additions & 5 deletions network/src/routing_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum SecretOrigin {
}

// Intermediate : Middle state in changing state, ex) A state -> Intermediate -> B state
// Discovery flow : Candidate -> Alive -> KeyPairShared -> SecretShared -> TemporaryNonceShared -> SessionShared -> Established
// Discovery flow : Candidate -> Alive -> KeyPairShared -> SecretShared -> TemporaryNonceShared -> SessionShared -> (Establishing) -> Established
// Offline secret exchange flow : SecretpreImported -> TemporaryNonceShared -> SessionShared -> Established
#[derive(Clone, Debug, PartialEq)]
enum State {
Expand All @@ -45,6 +45,7 @@ enum State {
SecretShared(Secret),
TemporaryNonceShared(Secret, Nonce, SecretOrigin),
SessionShared(Session),
Establishing(Session),
Established(NodeId),
Banned,
}
Expand Down Expand Up @@ -153,6 +154,14 @@ impl RoutingTable {
}

pub fn remove_node(&self, addr: SocketAddr) -> bool {
self.remove_node_internal(addr, false)
}

pub fn remove_node_on_shutdown(&self, addr: SocketAddr) -> bool {
self.remove_node_internal(addr, true)
}

fn remove_node_internal(&self, addr: SocketAddr, on_shutdown: bool) -> bool {
let mut entries = self.entries.write();
let mut remote_to_local_node_ids = self.remote_to_local_node_ids.write();

Expand All @@ -166,6 +175,12 @@ impl RoutingTable {
remote_to_local_node_ids.remove(&remote_node_id);
return false
}
State::SessionShared(_) => {
entry.set(old_state);
if on_shutdown {
return false
}
}
_ => {
entry.set(old_state);
}
Expand Down Expand Up @@ -393,20 +408,57 @@ impl RoutingTable {
false
}

pub fn set_establishing(&self, remote_address: &SocketAddr) -> bool {
let entries = self.entries.read();
let remote_node_id = remote_address.into();
if let Some(entry) = entries.get(&remote_node_id) {
let entry = entry.lock();
let old_state = entry.replace(State::Intermediate);
if let State::SessionShared(session) = old_state {
entry.set(State::Establishing(session));
ctrace!(ROUTING_TABLE, "Connection to {} set establishing", remote_address);
return true
}
entry.set(old_state);
}
ctrace!(ROUTING_TABLE, "Cannot set connection to {} as establishing", remote_address);
false
}

pub fn establish(&self, remote_address: &SocketAddr) -> bool {
let entries = self.entries.read();
let remote_node_id = remote_address.into();
if let Some(entry) = entries.get(&remote_node_id) {
let entry = entry.lock();
let old_state = entry.replace(State::Intermediate);
if let State::SessionShared(_) = old_state {
entry.set(State::Established(remote_node_id));
ctrace!(ROUTING_TABLE, "Connection to {} established", remote_address);
match old_state {
State::SessionShared(_) | State::Establishing(_) => {
entry.set(State::Established(remote_node_id));
ctrace!(ROUTING_TABLE, "Connection to {} is established", remote_address);
return true
}
_ => {}
}
entry.set(old_state);
}
ctrace!(ROUTING_TABLE, "Cannot establish connection to {}", remote_address);
false
}

pub fn reset_session(&self, remote_address: &SocketAddr) -> bool {
let entries = self.entries.read();
let remote_node_id = remote_address.into();
if let Some(entry) = entries.get(&remote_node_id) {
let entry = entry.lock();
let old_state = entry.replace(State::Intermediate);
if let State::Establishing(session) = old_state {
entry.set(State::SessionShared(session));
ctrace!(ROUTING_TABLE, "Connection to {} is ready to reconnect", remote_address);
return true
}
entry.set(old_state);
}
ctrace!(ROUTING_TABLE, "Cannot establish connection to {} established", remote_address);
ctrace!(ROUTING_TABLE, "Cannot reset connection to {}, because it's not establishing", remote_address);
false
}

Expand Down