私の備忘録がないわね...私の...

画像処理とかプログラミングのお話。

pytorchで複数の画像の複数のピクセルを同時に変更する

題名が言いたいこと。

複数の画像が存在し、それぞれの画像の複数のピクセルの値を変更したい。このときピクセルの場所, 変更後のピクセルの値はそれぞれ異なる。

for文を回せば簡単なのですが、pytorchやnumpyを使う時はできるだけスライスを使いたいです。

今回は以下のような5枚の画像, RGB, 縦・横2x2のデータを考えます。

data = torch.zeros((2, 3, 4, 4))
tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]],


        [[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]])

またそれぞれの画像の変えたい場所と値は以下のように定義します。ピクセルの中身の5つの要素はx, y, R, G, Bを表しています。

point = torch.tensor(
    [
        [ # 1枚目の画像
            [0, 0, 100, 200, 300], # 1個目のピクセル
            [1, 1, 400, 500, 600] # 2個目のピクセル
        ],
        [ # 2枚目の画像
            [0, 1, 700, 800, 900], # 1個目のピクセル
            [1, 0, 1000, 1100, 1200] # 2個目のピクセル
        ]
    ]
)

for文を回す

上でforは使いたくないと言いつつ、一応書いてみます。

a = data.clone()
# 画像
for i in range(len(point)):
    # ピクセル
    for j in range(len(point[i])):
        # RGB
        for k in range(3):
            a[i, k, point[i, j, 1], point[i, j, 0]] = point[i, j, k+2]
tensor([[[[ 100.,    0.,    0.,    0.],
          [   0.,  400.,    0.,    0.],
          [   0.,    0.,    0.,    0.],
          [   0.,    0.,    0.,    0.]],

         [[ 200.,    0.,    0.,    0.],
          [   0.,  500.,    0.,    0.],
          [   0.,    0.,    0.,    0.],
          [   0.,    0.,    0.,    0.]],

         [[ 300.,    0.,    0.,    0.],
          [   0.,  600.,    0.,    0.],
          [   0.,    0.,    0.,    0.],
          [   0.,    0.,    0.,    0.]]],


        [[[   0., 1000.,    0.,    0.],
          [ 700.,    0.,    0.,    0.],
          [   0.,    0.,    0.,    0.],
          [   0.,    0.,    0.,    0.]],

         [[   0., 1100.,    0.,    0.],
          [ 800.,    0.,    0.,    0.],
          [   0.,    0.,    0.,    0.],
          [   0.,    0.,    0.,    0.]],

         [[   0., 1200.,    0.,    0.],
          [ 900.,    0.,    0.,    0.],
          [   0.,    0.,    0.,    0.],
          [   0.,    0.,    0.,    0.]]]])

成功はしてますが、もう少し賢く書きたいです。

工夫1

例えばRGBのforをスライスに変えましょう。

b = data.clone()
# 画像
for i in range(len(point)):
    # ピクセル
    for j in range(len(point[i])):
        b[i, :, point[i, j, 1], point[i, j, 0]] = point[i, j, 2:]

結果は省きますが、これはうまく行きます。ただもう少しforを取り除きたいです。

工夫2

では以下はどうでしょうか?今度はピクセルをスライスでやろうとしてます。

c = data.clone()
# 画像
for i in range(len(point)):
    c[i, :, point[i, :, 1], point[i, :, 0]] = point[i, :, 2:]

これは以下のようなエラーがでます。

RuntimeError: shape mismatch: value tensor of shape [2, 3] cannot be broadcast to indexing result of shape [3, 2]

スライスを2個以上使うには単純に書き換えるだけでは上手くいかなさそうです。ただ今回はRGBとかの次元を合わせれば良いだけなのでpermute関数で軸を変えてやれば良さそうです。

d = data.clone()
# 画像
for i in range(len(point)):
    d[i, :, point[i, :, 1], point[i, :, 0]] = point[i, :, 2:].permute(1, 0).type(torch.float)

これは上手く行きます。typeは型で怒られたので入れてます。

工夫3

ここまでできたのならforは使いたくないです。直感的に書き換えてみましょう。

e = data.clone()
e[:, :, point[:, :, 1], point[:, :, 0]] = point[:, :, 2:]

まぁ当然エラーが出ます。そして今回は単純に軸を変えるだけでは上手くいかなさそうです。

RuntimeError: shape mismatch: value tensor of shape [2, 2, 3] cannot be broadcast to indexing result of shape [2, 3, 2, 2]

これは以下のようにすると上手くいきます。

e = data.clone()
image_number = [[0, 0], [1, 1]]
e[image_number, :, point[:, :, 1], point[:, :, 0]] = point[:, :, 2:].type(torch.float)

画像の指定方法とpoint[:, :, 0], point[:, :, 1]の形を合わせてます。