# Standard library first, then third-party (PEP 8 import grouping).
import inspect

import numpy as np
import torch
def print_params(*args):
    """Pretty-print each argument together with its caller-scope name.

    Recursively formats dicts, lists, tuples, ``torch.Tensor`` and
    ``numpy.ndarray`` values (array values are truncated to
    ``max_elements`` entries), and falls back to ``repr`` for scalars.

    Usage:
        print_params(arc_equation, gt_mask_ends, gt_mask_params)

    Note:
        Name resolution matches by object identity against the caller's
        locals, so literals/expressions passed inline print as "unknown".
    """
    frame = inspect.currentframe()
    # currentframe() can return None on Python implementations without
    # stack-frame support; degrade to "unknown" names rather than crash.
    caller = frame.f_back if frame is not None else None
    try:
        local_vars = caller.f_locals if caller is not None else {}
        # Match each argument to a caller-local name by identity.
        arg_names = [
            next((name for name, val in local_vars.items() if val is arg), "unknown")
            for arg in args
        ]
    finally:
        # Break the frame reference cycle explicitly, as recommended by
        # the `inspect` module documentation.
        del frame, caller

    def _print(obj, indent=0, max_elements=10):
        # Recursive pretty-printer; `indent` is the number of leading spaces.
        prefix = " " * indent
        if isinstance(obj, dict):
            print(f"{prefix}dict:")
            for k, v in obj.items():
                print(f"{prefix}  key: {k}")
                _print(v, indent + 2, max_elements)
        elif isinstance(obj, (list, tuple)):
            # Tuples were previously dumped via repr; recurse like lists.
            label = "list" if isinstance(obj, list) else "tuple"
            print(f"{prefix}{label} (len={len(obj)}):")
            for i, v in enumerate(obj):
                print(f"{prefix}  [{i}]")
                _print(v, indent + 2, max_elements)
        elif isinstance(obj, torch.Tensor):
            meta = f"shape={tuple(obj.shape)}, dtype={obj.dtype}, device={obj.device}"
            if obj.numel() > max_elements:
                values = obj.flatten()[:max_elements].tolist()
                print(f"{prefix}Tensor({meta}, values={values} ...)")
            else:
                print(f"{prefix}Tensor({meta}, values={obj.tolist()})")
        elif isinstance(obj, np.ndarray):
            meta = f"shape={obj.shape}, dtype={obj.dtype}"
            if obj.size > max_elements:
                # BUG FIX: convert the truncated slice with .tolist() so the
                # output is a plain Python list, consistent with the Tensor
                # branch and the non-truncated ndarray branch (the original
                # printed numpy's space-separated array repr here).
                values = obj.flatten()[:max_elements].tolist()
                print(f"{prefix}ndarray({meta}, values={values} ...)")
            else:
                print(f"{prefix}ndarray({meta}, values={obj.tolist()})")
        else:
            print(f"{prefix}{repr(obj)}")

    # Print each variable with its resolved name.
    for name, value in zip(arg_names, args):
        print(f"\n=== {name} ===")
        _print(value)