pytorchで複数の画像の複数のピクセルを同時に変更する
題名が言いたいこと。
複数の画像が存在し、それぞれの画像の複数のピクセルの値を変更したい。このときピクセルの場所, 変更後のピクセルの値はそれぞれ異なる。
for
文を回せば簡単なのですが、pytorchやnumpyを使う時はできるだけスライスを使いたいです。
今回は以下のような5枚の画像, RGB, 縦・横2x2のデータを考えます。
data = torch.zeros((2, 3, 4, 4))
tensor([[[[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]], [[[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]]])
またそれぞれの画像の変えたい場所と値は以下のように定義します。ピクセルの中身の5つの要素はx, y, R, G, Bを表しています。
point = torch.tensor( [ [ # 1枚目の画像 [0, 0, 100, 200, 300], # 1個目のピクセル [1, 1, 400, 500, 600] # 2個目のピクセル ], [ # 2枚目の画像 [0, 1, 700, 800, 900], # 1個目のピクセル [1, 0, 1000, 1100, 1200] # 2個目のピクセル ] ] )
for
文を回す
上でfor
は使いたくないと言いつつ、一応書いてみます。
a = data.clone() # 画像 for i in range(len(point)): # ピクセル for j in range(len(point[i])): # RGB for k in range(3): a[i, k, point[i, j, 1], point[i, j, 0]] = point[i, j, k+2]
tensor([[[[ 100., 0., 0., 0.], [ 0., 400., 0., 0.], [ 0., 0., 0., 0.], [ 0., 0., 0., 0.]], [[ 200., 0., 0., 0.], [ 0., 500., 0., 0.], [ 0., 0., 0., 0.], [ 0., 0., 0., 0.]], [[ 300., 0., 0., 0.], [ 0., 600., 0., 0.], [ 0., 0., 0., 0.], [ 0., 0., 0., 0.]]], [[[ 0., 1000., 0., 0.], [ 700., 0., 0., 0.], [ 0., 0., 0., 0.], [ 0., 0., 0., 0.]], [[ 0., 1100., 0., 0.], [ 800., 0., 0., 0.], [ 0., 0., 0., 0.], [ 0., 0., 0., 0.]], [[ 0., 1200., 0., 0.], [ 900., 0., 0., 0.], [ 0., 0., 0., 0.], [ 0., 0., 0., 0.]]]])
成功はしてますが、もう少し賢く書きたいです。
工夫1
例えばRGBのfor
をスライスに変えましょう。
b = data.clone() # 画像 for i in range(len(point)): # ピクセル for j in range(len(point[i])): b[i, :, point[i, j, 1], point[i, j, 0]] = point[i, j, 2:]
結果は省きますが、これはうまく行きます。ただもう少しfor
を取り除きたいです。
工夫2
では以下はどうでしょうか?今度はピクセルをスライスでやろうとしてます。
c = data.clone() # 画像 for i in range(len(point)): c[i, :, point[i, :, 1], point[i, :, 0]] = point[i, :, 2:]
これは以下のようなエラーがでます。
RuntimeError: shape mismatch: value tensor of shape [2, 3] cannot be broadcast to indexing result of shape [3, 2]
スライスを2個以上使うには単純に書き換えるだけでは上手くいかなさそうです。ただ今回はRGBとかの次元を合わせれば良いだけなのでpermute
関数で軸を変えてやれば良さそうです。
d = data.clone() # 画像 for i in range(len(point)): d[i, :, point[i, :, 1], point[i, :, 0]] = point[i, :, 2:].permute(1, 0).type(torch.float)
これは上手く行きます。type
は型で怒られたので入れてます。
工夫3
ここまでできたのならfor
は使いたくないです。直感的に書き換えてみましょう。
e = data.clone() e[:, :, point[:, :, 1], point[:, :, 0]] = point[:, :, 2:]
まぁ当然エラーが出ます。そして今回は単純に軸を変えるだけでは上手くいかなさそうです。
RuntimeError: shape mismatch: value tensor of shape [2, 2, 3] cannot be broadcast to indexing result of shape [2, 3, 2, 2]
これは以下のようにすると上手くいきます。
e = data.clone() image_number = [[0, 0], [1, 1]] e[image_number, :, point[:, :, 1], point[:, :, 0]] = point[:, :, 2:].type(torch.float)
画像の指定方法とpoint[:, :, 0]
, point[:, :, 1]
の形を合わせてます。