Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 121 additions & 78 deletions crates/stdlib/src/ssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ mod _ssl {

// Import error types used in this module (others are exposed via pymodule(with(...)))
use super::error::{
PySSLEOFError, PySSLError, create_ssl_want_read_error, create_ssl_want_write_error,
PySSLError, create_ssl_eof_error, create_ssl_want_read_error, create_ssl_want_write_error,
};
use alloc::sync::Arc;
use core::{
Expand Down Expand Up @@ -1903,6 +1903,7 @@ mod _ssl {
client_hello_buffer: PyMutex::new(None),
shutdown_state: PyMutex::new(ShutdownState::NotStarted),
pending_tls_output: PyMutex::new(Vec::new()),
write_buffered_len: PyMutex::new(0),
deferred_cert_error: Arc::new(ParkingRwLock::new(None)),
};

Expand Down Expand Up @@ -1974,6 +1975,7 @@ mod _ssl {
client_hello_buffer: PyMutex::new(None),
shutdown_state: PyMutex::new(ShutdownState::NotStarted),
pending_tls_output: PyMutex::new(Vec::new()),
write_buffered_len: PyMutex::new(0),
deferred_cert_error: Arc::new(ParkingRwLock::new(None)),
};

Expand Down Expand Up @@ -2345,6 +2347,10 @@ mod _ssl {
// but the socket cannot accept all the data immediately
#[pytraverse(skip)]
pub(crate) pending_tls_output: PyMutex<Vec<u8>>,
// Tracks bytes already buffered in rustls for the current write operation
// Prevents duplicate writes when retrying after WantWrite/WantRead
#[pytraverse(skip)]
pub(crate) write_buffered_len: PyMutex<usize>,
// Deferred client certificate verification error (for TLS 1.3)
// Stores error message if client cert verification failed during handshake
// Error is raised on first I/O operation after handshake
Expand Down Expand Up @@ -2604,6 +2610,36 @@ mod _ssl {
Ok(timed_out)
}

// Internal implementation with explicit timeout override
pub(crate) fn sock_wait_for_io_with_timeout(
&self,
kind: SelectKind,
timeout: Option<std::time::Duration>,
vm: &VirtualMachine,
) -> PyResult<bool> {
if self.is_bio_mode() {
// BIO mode doesn't use select
return Ok(false);
}

if let Some(t) = timeout
&& t.is_zero()
{
// Non-blocking mode - don't use select
return Ok(false);
}

let py_socket: PyRef<PySocket> = self.sock.clone().try_into_value(vm)?;
let socket = py_socket
.sock()
.map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?;

let timed_out = sock_select(&socket, kind, timeout)
.map_err(|e| vm.new_os_error(format!("select failed: {e}")))?;

Ok(timed_out)
}

// SNI (Server Name Indication) Helper Methods:
// These methods support the server-side handshake SNI callback mechanism

Expand Down Expand Up @@ -2783,6 +2819,7 @@ mod _ssl {
let is_non_blocking = socket_timeout.map(|t| t.is_zero()).unwrap_or(false);

let mut sent_total = 0;

while sent_total < pending.len() {
// Calculate timeout: use deadline if provided, otherwise use socket timeout
let timeout_to_use = if let Some(dl) = deadline {
Expand Down Expand Up @@ -2810,6 +2847,9 @@ mod _ssl {
if timed_out {
// Keep unsent data in pending buffer
*pending = pending[sent_total..].to_vec();
if is_non_blocking {
return Err(create_ssl_want_write_error(vm).upcast());
}
return Err(
timeout_error_msg(vm, "The write operation timed out".to_string()).upcast(),
);
Expand All @@ -2824,6 +2864,7 @@ mod _ssl {
*pending = pending[sent_total..].to_vec();
return Err(create_ssl_want_write_error(vm).upcast());
}
// Socket said ready but sent 0 bytes - retry
continue;
}
sent_total += sent;
Expand Down Expand Up @@ -2916,6 +2957,9 @@ mod _ssl {
pub(crate) fn blocking_flush_all_pending(&self, vm: &VirtualMachine) -> PyResult<()> {
// Get socket timeout to respect during flush
let timeout = self.get_socket_timeout(vm)?;
if timeout.map(|t| t.is_zero()).unwrap_or(false) {
return self.flush_pending_tls_output(vm, None);
}

loop {
let pending_data = {
Expand Down Expand Up @@ -2948,8 +2992,7 @@ mod _ssl {
let mut pending = self.pending_tls_output.lock();
pending.drain(..sent);
}
// If sent == 0, socket wasn't ready despite select() saying so
// Continue loop to retry - this avoids infinite loops
// If sent == 0, loop will retry with sock_select
}
Err(e) => {
if is_blocking_io_error(&e, vm) {
Expand Down Expand Up @@ -3515,16 +3558,60 @@ mod _ssl {
return_data(buf, &buffer, vm)
}
Err(crate::ssl::compat::SslError::Eof) => {
// If plaintext is still buffered, return it before EOF.
let pending = {
let mut conn_guard = self.connection.lock();
let conn = match conn_guard.as_mut() {
Some(conn) => conn,
None => return Err(create_ssl_eof_error(vm).upcast()),
};
use std::io::BufRead;
let mut reader = conn.reader();
reader.fill_buf().map(|buf| buf.len()).unwrap_or(0)
};
if pending > 0 {
let mut buf = vec![0u8; pending.min(len)];
let read_retry = {
let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
};
if let Ok(n) = read_retry {
buf.truncate(n);
return return_data(buf, &buffer, vm);
}
}
// EOF occurred in violation of protocol (unexpected closure)
Err(vm
.new_os_subtype_error(
PySSLEOFError::class(&vm.ctx).to_owned(),
None,
"EOF occurred in violation of protocol",
)
.upcast())
Err(create_ssl_eof_error(vm).upcast())
}
Err(crate::ssl::compat::SslError::ZeroReturn) => {
// If plaintext is still buffered, return it before clean EOF.
let pending = {
let mut conn_guard = self.connection.lock();
let conn = match conn_guard.as_mut() {
Some(conn) => conn,
None => return return_data(vec![], &buffer, vm),
};
use std::io::BufRead;
let mut reader = conn.reader();
reader.fill_buf().map(|buf| buf.len()).unwrap_or(0)
};
if pending > 0 {
let mut buf = vec![0u8; pending.min(len)];
let read_retry = {
let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
};
if let Ok(n) = read_retry {
buf.truncate(n);
return return_data(buf, &buffer, vm);
}
}
// Clean closure with close_notify - return empty data
return_data(vec![], &buffer, vm)
}
Expand Down Expand Up @@ -3580,21 +3667,17 @@ mod _ssl {
let data_bytes = data.borrow_buf();
let data_len = data_bytes.len();

// return 0 immediately for empty write
if data_len == 0 {
return Ok(0);
}

// Ensure handshake is done - if not, complete it first
// This matches OpenSSL behavior where SSL_write() auto-completes handshake
// Ensure handshake is done (SSL_write auto-completes handshake)
if !*self.handshake_done.lock() {
self.do_handshake(vm)?;
}

// Check if connection has been shut down
// After unwrap()/shutdown(), write operations should fail with SSLError
let shutdown_state = *self.shutdown_state.lock();
if shutdown_state != ShutdownState::NotStarted {
// Check shutdown state
if *self.shutdown_state.lock() != ShutdownState::NotStarted {
return Err(vm
.new_os_subtype_error(
PySSLError::class(&vm.ctx).to_owned(),
Expand All @@ -3604,76 +3687,32 @@ mod _ssl {
.upcast());
}

{
// Call ssl_write (matches CPython's SSL_write_ex loop)
let result = {
let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;

let is_bio = self.is_bio_mode();
let data: &[u8] = data_bytes.as_ref();
crate::ssl::compat::ssl_write(conn, data_bytes.as_ref(), self, vm)
};

// CRITICAL: Flush any pending TLS data before writing new data
// This ensures TLS 1.3 Finished message reaches server before application data
// Without this, server may not be ready to process our data
if !is_bio {
self.flush_pending_tls_output(vm, None)?;
match result {
Ok(n) => {
self.check_deferred_cert_error(vm)?;
Ok(n)
}

// Write data in chunks to avoid filling the internal TLS buffer
// rustls has a limited internal buffer, so we need to flush periodically
const CHUNK_SIZE: usize = 16384; // 16KB chunks (typical TLS record size)
let mut written = 0;

while written < data.len() {
let chunk_end = core::cmp::min(written + CHUNK_SIZE, data.len());
let chunk = &data[written..chunk_end];

// Write chunk to TLS layer
{
let mut writer = conn.writer();
use std::io::Write;
writer
.write_all(chunk)
.map_err(|e| vm.new_os_error(format!("Write failed: {e}")))?;
// Flush to ensure data is converted to TLS records
writer
.flush()
.map_err(|e| vm.new_os_error(format!("Flush failed: {e}")))?;
}

written = chunk_end;

// Flush TLS data to socket after each chunk
if conn.wants_write() {
if is_bio {
self.write_pending_tls(conn, vm)?;
} else {
// Socket mode: flush all pending TLS data
// First, try to send any previously pending data
self.flush_pending_tls_output(vm, None)?;

while conn.wants_write() {
let mut buf = Vec::new();
conn.write_tls(&mut buf).map_err(|e| {
vm.new_os_error(format!("TLS write failed: {e}"))
})?;

if !buf.is_empty() {
// Try to send TLS data, saving unsent bytes to pending buffer
self.send_tls_output(buf, vm)?;
}
}
}
}
Err(crate::ssl::compat::SslError::WantRead) => {
Err(create_ssl_want_read_error(vm).upcast())
}
Err(crate::ssl::compat::SslError::WantWrite) => {
Err(create_ssl_want_write_error(vm).upcast())
}
Err(crate::ssl::compat::SslError::Timeout(msg)) => {
Err(timeout_error_msg(vm, msg).upcast())
}
Err(e) => Err(e.into_py_err(vm)),
}

// Check for deferred certificate verification errors (TLS 1.3)
// Must be checked AFTER write completes, as the error may be set during I/O
self.check_deferred_cert_error(vm)?;

Ok(data_len)
}

#[pymethod]
Expand Down Expand Up @@ -4013,6 +4052,10 @@ mod _ssl {

// Write close_notify to outgoing buffer/BIO
self.write_pending_tls(conn, vm)?;
// Ensure close_notify and any pending TLS data are flushed
if !is_bio {
self.flush_pending_tls_output(vm, None)?;
}

// Update state
*self.shutdown_state.lock() = ShutdownState::SentCloseNotify;
Expand Down
Loading
Loading