# show_prams.py (2.2 KB)

import inspect

import numpy as np
import torch
  4. def print_params(*args):
  5. """
  6. Pretty-print multiple variables, recursively handling
  7. list/dict/torch.Tensor/numpy.ndarray/scalar, and automatically
  8. print the variable names.
  9. Usage:
  10. print_params(arc_equation, gt_mask_ends, gt_mask_params, arc_pos_matched_idxs)
  11. """
  12. # Get the calling frame and local variables
  13. frame = inspect.currentframe().f_back
  14. local_vars = frame.f_locals
  15. # Try to match args with variable names in caller's scope
  16. arg_names = []
  17. for arg in args:
  18. found = False
  19. for name, val in local_vars.items():
  20. if val is arg:
  21. arg_names.append(name)
  22. found = True
  23. break
  24. if not found:
  25. arg_names.append("unknown")
  26. # Recursive printer
  27. def _print(obj, indent=0, max_elements=10):
  28. prefix = " " * indent
  29. if isinstance(obj, dict):
  30. print(f"{prefix}dict:")
  31. for k, v in obj.items():
  32. print(f"{prefix} key: {k}")
  33. _print(v, indent + 2, max_elements)
  34. elif isinstance(obj, list):
  35. print(f"{prefix}list (len={len(obj)}):")
  36. for i, v in enumerate(obj):
  37. print(f"{prefix} [{i}]")
  38. _print(v, indent + 2, max_elements)
  39. elif isinstance(obj, torch.Tensor):
  40. if obj.numel() > max_elements:
  41. print(f"{prefix}Tensor(shape={tuple(obj.shape)}, dtype={obj.dtype}, device={obj.device}, values={obj.flatten()[:max_elements].tolist()} ...)")
  42. else:
  43. print(f"{prefix}Tensor(shape={tuple(obj.shape)}, dtype={obj.dtype}, device={obj.device}, values={obj.tolist()})")
  44. elif isinstance(obj, np.ndarray):
  45. if obj.size > max_elements:
  46. print(f"{prefix}ndarray(shape={obj.shape}, dtype={obj.dtype}, values={obj.flatten()[:max_elements]} ...)")
  47. else:
  48. print(f"{prefix}ndarray(shape={obj.shape}, dtype={obj.dtype}, values={obj.tolist()})")
  49. else:
  50. print(f"{prefix}{repr(obj)}")
  51. # Print each variable with its name
  52. for name, value in zip(arg_names, args):
  53. print(f"\n=== {name} ===")
  54. _print(value)