import torch
import numpy as np
import inspect


def print_params(*args, max_elements=10):
    """Pretty-print several variables together with their caller-side names.

    Each argument is printed under a ``=== name ===`` header. Dicts, lists and
    tuples are walked recursively; ``torch.Tensor`` and ``numpy.ndarray``
    values are summarised (shape, dtype, and for tensors the device) and their
    values are truncated to the first ``max_elements`` entries when larger.
    Anything else is printed with ``repr``.

    Variable names are recovered by identity-matching each argument against
    the caller's local variables. The first local that *is* the argument wins,
    so aliases or interned values (small ints, short strings) may be reported
    under a different name than the one written at the call site. Arguments
    that match no local (literals, expressions) are labelled ``unknown``.

    Usage:
        print_params(arc_equation, gt_mask_ends, gt_mask_params)

    Args:
        *args: The objects to print.
        max_elements: Maximum number of tensor/array elements shown before the
            values are truncated with ``...``.

    Returns:
        None. Output is written to stdout via ``print``.
    """
    # inspect.currentframe() may return None on interpreters without Python
    # stack-frame support; in that case every name falls back to "unknown".
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    try:
        local_vars = caller.f_locals if caller is not None else {}
        arg_names = [
            next(
                (name for name, val in local_vars.items() if val is arg),
                "unknown",
            )
            for arg in args
        ]
    finally:
        # Drop frame references explicitly to avoid reference cycles that
        # would keep the caller's locals alive longer than necessary.
        del frame, caller

    def _print(obj, indent=0):
        """Recursively print ``obj`` indented by ``indent`` characters."""
        prefix = " " * indent
        if isinstance(obj, dict):
            print(f"{prefix}dict:")
            for k, v in obj.items():
                print(f"{prefix} key: {k}")
                _print(v, indent + 2)
        elif isinstance(obj, (list, tuple)):
            # Tuples are walked like lists so nested tensors/arrays inside
            # them are summarised instead of being dumped in full via repr.
            kind = "list" if isinstance(obj, list) else "tuple"
            print(f"{prefix}{kind} (len={len(obj)}):")
            for i, v in enumerate(obj):
                print(f"{prefix} [{i}]")
                _print(v, indent + 2)
        elif isinstance(obj, torch.Tensor):
            if obj.numel() > max_elements:
                print(
                    f"{prefix}Tensor(shape={tuple(obj.shape)}, "
                    f"dtype={obj.dtype}, device={obj.device}, "
                    f"values={obj.flatten()[:max_elements].tolist()} ...)"
                )
            else:
                print(
                    f"{prefix}Tensor(shape={tuple(obj.shape)}, "
                    f"dtype={obj.dtype}, device={obj.device}, "
                    f"values={obj.tolist()})"
                )
        elif isinstance(obj, np.ndarray):
            if obj.size > max_elements:
                # ravel() avoids copying the whole array when a view suffices;
                # only the first few elements are needed for display.
                print(
                    f"{prefix}ndarray(shape={obj.shape}, dtype={obj.dtype}, "
                    f"values={obj.ravel()[:max_elements]} ...)"
                )
            else:
                print(
                    f"{prefix}ndarray(shape={obj.shape}, dtype={obj.dtype}, "
                    f"values={obj.tolist()})"
                )
        else:
            print(f"{prefix}{repr(obj)}")

    # Print each variable under a header carrying its resolved name.
    for name, value in zip(arg_names, args):
        print(f"\n=== {name} ===")
        _print(value)