"""Build an all-white test image with its top rows blacked out and display it."""

import torch


def make_masked_image(
    height: int = 128,
    width: int = 128,
    channels: int = 3,
    masked_rows: int = 30,
) -> torch.Tensor:
    """Return an all-ones float image whose first ``masked_rows`` rows are zeroed.

    Args:
        height: Number of image rows.
        width: Number of image columns.
        channels: Number of colour channels (3 gives an RGB image for ``imshow``).
        masked_rows: How many rows, counted from the top, to set to 0.
            A value larger than ``height`` zeroes the whole image (slice
            semantics); 0 leaves the image untouched.

    Returns:
        A tensor of shape ``(height, width, channels)`` with the default
        float dtype (``float32`` unless changed globally).

    Raises:
        ValueError: If ``masked_rows`` is negative. A negative slice bound would
            otherwise silently mask "all but the last N" rows instead.
    """
    if masked_rows < 0:
        raise ValueError(f"masked_rows must be non-negative, got {masked_rows}")

    img = torch.ones((height, width, channels))

    # A 2-D boolean spatial mask selects whole pixels: indexing a (H, W, C)
    # tensor with an (H, W) bool mask addresses all C channels of each
    # selected pixel. This avoids building a per-channel float mask and
    # comparing it with ``== 1``.
    mask = torch.zeros((height, width), dtype=torch.bool)
    mask[:masked_rows] = True
    img[mask] = 0
    return img


def main() -> None:
    """Render the default masked image in a matplotlib window."""
    # Imported lazily so make_masked_image can be used (and tested) without
    # matplotlib or a display backend.
    from matplotlib import pyplot as plt

    img = make_masked_image()
    # Hand imshow an explicit NumPy array rather than relying on implicit
    # tensor-to-array conversion.
    plt.imshow(img.numpy())
    plt.show()


if __name__ == "__main__":
    main()