ElasticTransform in PyTorch (4)



This content originally appeared on DEV Community and was authored by Super Kai (Kazuya Ito)

Buy Me a Coffee☕

*Memos:

ElasticTransform() can apply a random elastic (morphological) transformation to an image, as shown below. *This post is about the `alpha` and `sigma` arguments:

from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import ElasticTransform
from torchvision.transforms.functional import InterpolationMode

def _elastic_pets(alpha, sigma):
    """Return the pets dataset with ``ElasticTransform(alpha, sigma)`` attached."""
    return OxfordIIITPet(
        root="data",
        transform=ElasticTransform(alpha=alpha, sigma=sigma)
    )


# Reference dataset with no transform attached.
origin_data = OxfordIIITPet(root="data", transform=None)

# Naming scheme: `a` is alpha and `s` is sigma (`s01` means sigma=0.1).

# alpha=0
a0s01_data = _elastic_pets(alpha=0, sigma=0.1)
# Alternative noted by the author for a0s01_data:
# transform=ElasticTransform(alpha=[0, 0], sigma=[0.1, 0.1])
a0s1_data = _elastic_pets(alpha=0, sigma=1)
a0s10_data = _elastic_pets(alpha=0, sigma=10)
a0s40_data = _elastic_pets(alpha=0, sigma=40)

# alpha=10
a10s01_data = _elastic_pets(alpha=10, sigma=0.1)
# Alternative noted by the author for a10s01_data:
# transform=ElasticTransform(alpha=-10, sigma=0.1)
a10s1_data = _elastic_pets(alpha=10, sigma=1)
a10s10_data = _elastic_pets(alpha=10, sigma=10)
a10s40_data = _elastic_pets(alpha=10, sigma=40)

# alpha=100
a100s01_data = _elastic_pets(alpha=100, sigma=0.1)
# Alternative noted by the author for a100s01_data:
# transform=ElasticTransform(alpha=-100, sigma=0.1)
a100s1_data = _elastic_pets(alpha=100, sigma=1)
a100s10_data = _elastic_pets(alpha=100, sigma=10)
a100s40_data = _elastic_pets(alpha=100, sigma=40)

# alpha=1000
a1000s01_data = _elastic_pets(alpha=1000, sigma=0.1)
# Alternative noted by the author for a1000s01_data:
# transform=ElasticTransform(alpha=-1000, sigma=0.1)
a1000s1_data = _elastic_pets(alpha=1000, sigma=1)
a1000s10_data = _elastic_pets(alpha=1000, sigma=10)
a1000s40_data = _elastic_pets(alpha=1000, sigma=40)

# alpha=10000
a10000s01_data = _elastic_pets(alpha=10000, sigma=0.1)
# Alternative noted by the author for a10000s01_data:
# transform=ElasticTransform(alpha=-10000, sigma=0.1)
a10000s1_data = _elastic_pets(alpha=10000, sigma=1)
a10000s10_data = _elastic_pets(alpha=10000, sigma=10)
a10000s40_data = _elastic_pets(alpha=10000, sigma=40)

import matplotlib.pyplot as plt

def show_images1(data, main_title=None):
    """Draw up to five images from ``data`` side by side in a new figure.

    Args:
        data: Iterable of ``(image, label)`` pairs; only the image is drawn.
        main_title: Text passed to ``plt.suptitle`` as the figure title.
    """
    columns = 5
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    # zip() ends when the column numbers run out, so no more than
    # `columns` samples are pulled from `data`.
    for column, sample in zip(range(1, columns + 1), data):
        image, _ = sample
        plt.subplot(1, columns, column)
        plt.imshow(X=image)
        # Empty tick lists hide the axis ticks on both axes.
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

# Show the untransformed images first, then one group of four sigma values
# per alpha value; each group is preceded by a blank printed line.
show_images1(data=origin_data, main_title="origin_data")
for titled_group in (
    (
        ("a0s01_data", a0s01_data),
        ("a0s1_data", a0s1_data),
        ("a0s10_data", a0s10_data),
        ("a0s40_data", a0s40_data),
    ),
    (
        ("a10s01_data", a10s01_data),
        ("a10s1_data", a10s1_data),
        ("a10s10_data", a10s10_data),
        ("a10s40_data", a10s40_data),
    ),
    (
        ("a100s01_data", a100s01_data),
        ("a100s1_data", a100s1_data),
        ("a100s10_data", a100s10_data),
        ("a100s40_data", a100s40_data),
    ),
    (
        ("a1000s01_data", a1000s01_data),
        ("a1000s1_data", a1000s1_data),
        ("a1000s10_data", a1000s10_data),
        ("a1000s40_data", a1000s40_data),
    ),
    (
        ("a10000s01_data", a10000s01_data),
        ("a10000s1_data", a10000s1_data),
        ("a10000s10_data", a10000s10_data),
        ("a10000s40_data", a10000s40_data),
    ),
):
    print()
    for group_title, group_data in titled_group:
        show_images1(data=group_data, main_title=group_title)

# ↓ ↓ ↓ ↓ ↓ ↓ The code below does the same thing as the code above, applying the transform at display time. ↓ ↓ ↓ ↓ ↓ ↓
def show_images2(data, main_title=None, a=50, s=5,
                 ip=InterpolationMode.BILINEAR, f=0):
    """Draw up to five images from ``data``, elastically transformed on the fly.

    Args:
        data: Iterable of ``(image, label)`` pairs; only the image is drawn.
        main_title: Text passed to ``plt.suptitle`` as the figure title. The
            exact title ``"origin_data"`` also switches the transform off, so
            the images are shown unchanged.
        a: Passed to ``ElasticTransform`` as ``alpha``.
        s: Passed to ``ElasticTransform`` as ``sigma``.
        ip: Passed to ``ElasticTransform`` as ``interpolation``.
        f: Passed to ``ElasticTransform`` as ``fill``.
    """
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    # Build the transform once per figure instead of once per image: its
    # arguments do not change inside the loop.
    # NOTE(review): this relies on ElasticTransform drawing a fresh random
    # displacement on every call, as the torchvision docs describe.
    if main_title != "origin_data":
        et = ElasticTransform(alpha=a, sigma=s,
                              interpolation=ip, fill=f)
    else:
        et = None
    # zip() ends when the column numbers run out, so no more than five
    # samples are pulled from `data`.
    for i, (im, _) in zip(range(1, 6), data):
        plt.subplot(1, 5, i)
        plt.imshow(X=im if et is None else et(im))
        # Empty tick lists hide the axis ticks on both axes.
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

# Show the untransformed images first, then one group of four sigma values
# per alpha value; each group is preceded by a blank printed line.
show_images2(data=origin_data, main_title="origin_data")
for alpha_value in (0, 10, 100, 1000, 10000):
    print()
    # Each sigma is paired with the tag used for it in the figure title
    # (0.1 is written as "01").
    for sigma_value, sigma_tag in ((0.1, "01"), (1, "1"), (10, "10"), (40, "40")):
        show_images2(data=origin_data,
                     main_title=f"a{alpha_value}s{sigma_tag}_data",
                     a=alpha_value, s=sigma_value)

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description


This content originally appeared on DEV Community and was authored by Super Kai (Kazuya Ito)