diff --git a/Cargo.lock b/Cargo.lock index 69415ec..4b8eee4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -164,10 +164,10 @@ name = "blondie" version = "0.5.2" dependencies = [ "clap", - "crossbeam-skiplist", "inferno", "object 0.30.4", "pdb-addr2line", + "rustc-hash", "symsrv", "tokio", "windows", @@ -325,25 +325,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "crossbeam-epoch" -version = "0.9.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-skiplist" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df29de440c58ca2cc6e587ec3d22347551a32435fbde9d2bff64e78a9ffa151b" -dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", -] - [[package]] name = "crossbeam-utils" version = "0.8.20" @@ -1048,6 +1029,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustls" version = "0.21.12" diff --git a/Cargo.toml b/Cargo.toml index 15255f3..14b91bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ windows = { version = "0.44.0", features = [ "Win32_Storage_FileSystem", "Win32_System_SystemInformation" ] } -crossbeam-skiplist = "0.1.3" +rustc-hash = "1.1.0" # Only for the binary crate inferno = { version = "0.11", optional = true } clap = { version = "4.0.26", optional = true, features = ["derive"] } diff --git a/src/lib.rs b/src/lib.rs index d56942c..76efc66 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ use std::os::windows::ffi::OsStringExt; use std::path::PathBuf; use std::ptr::{addr_of, addr_of_mut}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::mpsc::{Receiver, Sender}; use std::sync::{Arc, Mutex, Weak}; use std::thread::JoinHandle; @@ -26,6 +27,7 @@ use crossbeam_skiplist::{SkipMap, SkipSet}; use object::Object; use pdb_addr2line::pdb::PDB; use pdb_addr2line::ContextPdbData; +use rustc_hash::FxHashMap; use windows::core::{GUID, PCSTR, PSTR}; use windows::Win32::Foundation::{ CloseHandle, ERROR_SUCCESS, ERROR_WMI_INSTANCE_NOT_FOUND, HANDLE, INVALID_HANDLE_VALUE, @@ -59,9 +61,9 @@ struct TraceContext { show_kernel_samples: bool, /// map[array_of_stacktrace_addrs] = sample_count - stack_counts_hashmap: SkipMap<[u64; MAX_STACK_DEPTH], AtomicU64>, + stack_counts_hashmap: FxHashMap<[u64; MAX_STACK_DEPTH], AtomicU64>, /// (image_path, image_base, image_size) - image_paths: SkipSet<(OsString, u64, u64)>, + image_paths: Vec<(OsString, u64, u64)>, } impl TraceContext { /// The Context takes ownership of the handle. @@ -101,25 +103,96 @@ impl Drop for TraceContext { } } +struct Context { + /// Keys are process IDs. + traces: FxHashMap>, + /// Receive new processes to subscribe to tracing. + subscribe_recv: Receiver>, + /// Set to true once the trace starts running, deallocated afterward. + trace_running: Weak, +} + /// Global state context provided to `event_record_callback`. -struct GlobalContext { - /// Starts out as false and is set to and remains true once the trace starts running. - trace_running: AtomicBool, - /// Keys are process ID. - traces: SkipMap, - /// Background processing thread. - processing_thread: Option>>, - /// Logfile allocation. +struct Global { + /// Send new processes to subscribe to tracing. + subscriber_send: Sender>, + /// Allocation for [`Context`]. + context: *const Context, + /// Allocation for event trace props, need deallocation after. + event_trace_props: Box, + /// Allocation for Logfile. log: EVENT_TRACE_LOGFILEA, - /// Must be `!Unpin` as because `log` contains a pointer to self. - _pin: PhantomPinned, } -unsafe impl Send for GlobalContext {} -unsafe impl Sync for GlobalContext {} + +impl Global { + fn start(self: &Arc, trace_context: TraceContext) -> TraceGuard { + let trace_context = Arc::new(trace_context); + self.traces + .insert((*trace_context).target_proc_pid, Arc::clone(&trace_context)); + TraceGuard { + global: Arc::clone(&self), + trace_context, + } + } +} + +impl Drop for Global { + fn drop(&mut self) { + let ret = unsafe { + // This unblocks ProcessTrace + ControlTraceA( + ::default(), + KERNEL_LOGGER_NAMEA, + addr_of_mut!(*self.event_trace_props).cast(), + EVENT_TRACE_CONTROL_STOP, + ) + }; + if ret != ERROR_SUCCESS { + eprintln!( + "Error dropping GlobalContext: {:?}", + get_last_error("ControlTraceA STOP ProcessTrace") + ); + } + } +} + +#[must_use] +struct TraceGuard { + global: Arc, + trace_context: Arc, +} +impl TraceGuard { + fn stop(mut self) -> TraceContext { + println!("BEFORE {}.", Arc::strong_count(&self.trace_context)); + + let entry = self + .global + .traces + .remove(&self.trace_context.target_proc_pid) + .expect("Cannot stop multiple times."); + println!("REMOVE {}", entry.remove()); + drop(entry); + drop(self.global); + + println!( + "TRYING TO TAKE FROM ARC {}.", + Arc::strong_count(&self.trace_context) + ); + loop { + match Arc::try_unwrap(self.trace_context) { + Ok(trace_context) => break trace_context, + Err(still_shared) => { + self.trace_context = still_shared; + } + } + std::hint::spin_loop(); + } + } +} /// Gets the global context. Begins tracing if not already running. -fn run_global_context() -> Result> { - static GLOBAL_CONTEXT: Mutex> = Mutex::new(Weak::new()); +fn get_global_context() -> Result> { + static GLOBAL_CONTEXT: Mutex> = Mutex::new(Weak::new()); let mut unlocked = GLOBAL_CONTEXT.lock().unwrap(); if let Some(global_context) = unlocked.upgrade() { @@ -175,27 +248,11 @@ fn run_global_context() -> Result> { // Events are delivered when the buffers are flushed (https://docs.microsoft.com/en-us/windows/win32/etw/logging-mode-constants) // We also use Image_Load events to know which dlls to load debug information from for symbol resolution // Which is enabled by the EVENT_TRACE_FLAG_IMAGE_LOAD flag - const KERNEL_LOGGER_NAMEA_LEN: usize = unsafe { - let mut ptr = KERNEL_LOGGER_NAMEA.0; - let mut len = 0; - while *ptr != 0 { - len += 1; - ptr = ptr.add(1); - } - len - }; const PROPS_SIZE: usize = size_of::() + KERNEL_LOGGER_NAMEA_LEN + 1; - #[derive(Clone)] - #[repr(C)] - #[allow(non_camel_case_types)] - struct EVENT_TRACE_PROPERTIES_WITH_STRING { - data: EVENT_TRACE_PROPERTIES, - s: [u8; KERNEL_LOGGER_NAMEA_LEN + 1], - } - let mut event_trace_props = EVENT_TRACE_PROPERTIES_WITH_STRING { + let mut event_trace_props = Box::new(EVENT_TRACE_PROPERTIES_WITH_STRING { data: EVENT_TRACE_PROPERTIES::default(), s: [0u8; KERNEL_LOGGER_NAMEA_LEN + 1], - }; + }); event_trace_props.data.EnableFlags = EVENT_TRACE_FLAG_PROFILE | EVENT_TRACE_FLAG_IMAGE_LOAD; event_trace_props.data.LogFileMode = EVENT_TRACE_REAL_TIME_MODE; event_trace_props.data.Wnode.BufferSize = PROPS_SIZE as u32; @@ -212,18 +269,18 @@ fn run_global_context() -> Result> { .s .copy_from_slice(&kernel_logger_name_with_nul[..]); - let kernel_logger_name_with_nul_pcstr = PCSTR(kernel_logger_name_with_nul.as_ptr()); + // let kernel_logger_name_with_nul_pcstr = PCSTR(kernel_logger_name_with_nul.as_ptr()); // Stop an existing session with the kernel logger, if it exists // We use a copy of `event_trace_props` since ControlTrace overwrites it { - let mut event_trace_props_copy = event_trace_props.clone(); + let mut event_trace_props_copy = (*event_trace_props).clone(); // SAFETY: controlled input. // https://learn.microsoft.com/en-us/windows/win32/api/evntrace/nf-evntrace-controltracea let control_stop_retcode = unsafe { ControlTraceA( None, - kernel_logger_name_with_nul_pcstr, - addr_of_mut!(event_trace_props_copy) as *mut _, + KERNEL_LOGGER_NAMEA, + addr_of_mut!(event_trace_props_copy).cast(), EVENT_TRACE_CONTROL_STOP, ) }; @@ -242,8 +299,8 @@ fn run_global_context() -> Result> { let start_retcode = unsafe { StartTraceA( addr_of_mut!(trace_session_handle), - kernel_logger_name_with_nul_pcstr, - addr_of_mut!(event_trace_props) as *mut _, + KERNEL_LOGGER_NAMEA, + addr_of_mut!(*event_trace_props).cast(), ) }; if start_retcode != ERROR_SUCCESS { @@ -276,10 +333,10 @@ fn run_global_context() -> Result> { } } - let mut context_arc = Arc::new(GlobalContext { + let mut context_arc = Arc::new(Global { trace_running: AtomicBool::new(false), traces: SkipMap::new(), - processing_thread: None, // Set later. + event_trace_props, log: EVENT_TRACE_LOGFILEA::default(), _pin: PhantomPinned, }); @@ -297,18 +354,19 @@ fn run_global_context() -> Result> { unsafe extern "system" fn event_record_callback(record: *mut EVENT_RECORD) { let provider_guid_data1 = (*record).EventHeader.ProviderId.data1; let event_opcode = (*record).EventHeader.EventDescriptor.Opcode; - let context = &mut *(*record).UserContext.cast::(); + let context = &mut *(*record).UserContext.cast::(); context.trace_running.store(true, Ordering::Relaxed); const EVENT_TRACE_TYPE_LOAD: u8 = 10; if event_opcode == EVENT_TRACE_TYPE_LOAD { let event = (*record).UserData.cast::().read_unaligned(); + // TODO: handle PID reuse??? let Some(trace_context) = context.traces.get(&event.ProcessId) else { // Ignore dlls for other processes return; }; - let trace_context = &mut **trace_context.value(); + let trace_context = trace_context.value(); let filename_p = (*record) .UserData @@ -340,11 +398,12 @@ fn run_global_context() -> Result> { let proc = ud_p.cast::().offset(2).read_unaligned(); let _thread = ud_p.cast::().offset(3).read_unaligned(); + // TODO: handle PID reuse??? let Some(trace_context) = context.traces.get(&proc) else { // Ignore stackwalks for other processes return; }; - let trace_context = &mut **trace_context.value(); + let trace_context = trace_context.value(); let stack_depth_32 = ((*record).UserDataLength - 16) / 4; let stack_depth_64 = stack_depth_32 / 2; @@ -432,10 +491,10 @@ fn run_global_context() -> Result> { return Err(get_last_error("OpenTraceA processing")); } - context.processing_thread = Some(std::thread::spawn(move || { + let _ = std::thread::spawn(move || { unsafe { SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL); - // This blocks + // This blocks until `EVENT_TRACE_CONTROL_STOP` on `GlobalContext::drop`. ProcessTrace(&[trace_processing_handle], None, None) }; @@ -444,7 +503,7 @@ fn run_global_context() -> Result> { return Err(get_last_error("Error closing trace")); } Ok(()) - })); + }); // Wait until we know for sure the trace is running while !context.trace_running.load(Ordering::Relaxed) { @@ -456,6 +515,24 @@ fn run_global_context() -> Result> { Ok(context_arc) } +const KERNEL_LOGGER_NAMEA_LEN: usize = unsafe { + let mut ptr = KERNEL_LOGGER_NAMEA.0; + let mut len = 0; + while *ptr != 0 { + len += 1; + ptr = ptr.add(1); + } + len +}; + +#[derive(Clone)] +#[repr(C)] +#[allow(non_camel_case_types)] +struct EVENT_TRACE_PROPERTIES_WITH_STRING { + data: EVENT_TRACE_PROPERTIES, + s: [u8; KERNEL_LOGGER_NAMEA_LEN + 1], +} + /// The main tracing logic. Traces the process with the given `target_process_id`. /// /// # Safety @@ -467,9 +544,9 @@ unsafe fn trace_from_process_id( kernel_stacks: bool, ) -> Result { let target_proc_handle = util::handle_from_process_id(target_process_id)?; - let mut context = + let trace_context = unsafe { TraceContext::new(target_proc_handle, target_process_id, kernel_stacks)? }; - // TODO: Do we need to Box the context? + let trace_guard = get_global_context()?.start(trace_context); // Resume the suspended process if is_suspended { @@ -489,32 +566,19 @@ unsafe fn trace_from_process_id( #[allow(non_snake_case)] let NtResumeProcess: extern "system" fn(isize) -> i32 = std::mem::transmute(NtResumeProcess); - NtResumeProcess(context.target_process_handle.0); + NtResumeProcess(target_proc_handle.0); } // Wait for it to end util::wait_for_process_by_handle(target_proc_handle)?; - // This unblocks ProcessTrace - let ret = ControlTraceA( - ::default(), - PCSTR(kernel_logger_name_with_nul.as_ptr()), - addr_of_mut!(event_trace_props) as *mut _, - EVENT_TRACE_CONTROL_STOP, - ); - if ret != ERROR_SUCCESS { - return Err(get_last_error("ControlTraceA STOP ProcessTrace")); - } - // Block until processing thread is done - // (Safeguard to make sure we don't deallocate the context before the other thread finishes using it) - processing_thread - .join() - .map_err(|_err_any| Error::UnknownError)??; - - if context.show_kernel_samples { - let kernel_module_paths = util::list_kernel_modules(); - context.image_paths.extend(kernel_module_paths); + let trace_context = trace_guard.stop(); + if trace_context.show_kernel_samples { + todo!(); + // let kernel_module_paths = util::list_kernel_modules(); + // trace_context.image_paths.extend(kernel_module_paths); } + Ok(trace_context) } /// The sampled results from a process execution @@ -560,8 +624,10 @@ pub fn trace_command( /// /// You can get them using [`CollectionResults::iter_callstacks`] pub struct CallStack<'a> { - stack: &'a [u64; MAX_STACK_DEPTH], + stack: [u64; MAX_STACK_DEPTH], sample_count: u64, + // TODO + _phantom: std::marker::PhantomData<&'a ()>, } /// An address from a callstack @@ -582,8 +648,10 @@ type PdbDb<'a, 'b> = std::collections::BTreeMap)>; /// Returns Vec<(image_base, image_size, image_name, addr2line pdb context)> -fn find_pdbs(images: &[(OsString, u64, u64)]) -> Vec<(u64, u64, OsString, OwnedPdb)> { - let mut pdb_db = Vec::with_capacity(images.len()); +fn find_pdbs( + images: impl IntoIterator, +) -> Vec<(u64, u64, OsString, OwnedPdb)> { + let mut pdb_db = Vec::new(); fn owned_pdb(pdb_file_bytes: Vec) -> Option { let pdb = PDB::open(std::io::Cursor::new(pdb_file_bytes)).ok()?; @@ -645,7 +713,7 @@ fn find_pdbs(images: &[(OsString, u64, u64)]) -> Vec<(u64, u64, OsString, OwnedP _ => continue, }; - pdb_db.push((*image_base, *image_size, image_name.to_owned(), pdb_ctx)); + pdb_db.push((image_base, image_size, image_name.to_owned(), pdb_ctx)); } else if use_symsrv { let pdb_filename = match pdb_path.file_name() { Some(x) => x, @@ -678,7 +746,7 @@ fn find_pdbs(images: &[(OsString, u64, u64)]) -> Vec<(u64, u64, OsString, OwnedP Some(x) => x, _ => continue, }; - pdb_db.push((*image_base, *image_size, image_name.to_owned(), pdb_ctx)); + pdb_db.push((image_base, image_size, image_name.to_owned(), pdb_ctx)); } } } @@ -706,7 +774,7 @@ impl<'a> CallStack<'a> { } let displacement = 0u64; let mut symbol_names_storage = reuse_vec(std::mem::take(v)); - for &addr in self.stack { + for addr in self.stack { if addr == 0 { *v = symbol_names_storage; return Ok(()); @@ -743,17 +811,19 @@ impl<'a> CallStack<'a> { Ok(()) } } + impl CollectionResults { /// Iterate the distinct callstacks sampled in this execution pub fn iter_callstacks(&self) -> impl std::iter::Iterator> { self.0.stack_counts_hashmap.iter().map(|x| CallStack { - stack: x.0, - sample_count: *x.1, + stack: x.key().clone(), + sample_count: x.value().load(Ordering::Relaxed), + _phantom: std::marker::PhantomData, }) } /// Resolve call stack symbols and write a dtrace-like sampling report to `w` pub fn write_dtrace(&self, mut w: W) -> Result<()> { - let pdbs = find_pdbs(&self.0.image_paths); + let pdbs = find_pdbs(self.0.image_paths.iter().map(|entry| entry.value().clone())); let pdb_db: PdbDb = pdbs .iter() .filter_map(|(a, b, c, d)| d.make_context().ok().map(|d| (*a, (*a, *b, c.clone(), d))))