torch.tensor()
とtorch.Tensor()
の違いについての備忘録です。
はじめに
PyTorchでtensorを作るときはtorch.tensor()メソッドが使われることが多いですね。
一方でtorch.Tensor()のようにクラスのコンストラクタをそのまま呼び出してもtensorを作れるように見えます。
これらふたつ、
torch.tensor()
torch.Tensor()
の違いがよく分からなかったので備忘録です。
torch.tensorとtorch.Tensorの違い
一言で
基本torch.tensor()
を使いましょう。
データ型を推論してくれるので便利です。
現状torch.Tensor()
の方を使う理由はあまりありません。
詳しく
まず簡単に挙動を見てみます。
x = torch.tensor([1, 2, 3]) print(x) print(x.dtype) X = torch.Tensor([1, 2, 3]) print(X) print(X.dtype) # 実行結果 # tensor([1, 2, 3]) # torch.int64 # tensor([1., 2., 3.]) # torch.float32
torch.tensor([1, 2, 3])
でtensorを作った場合はデータ型がtorch.int64
になっていますが、torch.Tensor([1, 2, 3])
の場合はデータ型がtorch.float32
になっています。
これはtorch.tensor()
が渡されたdataの型を推論するのに対して、torch.tensor()
ではtorch.FloatTensor
を返すようになっているからです。
もちろん、torch.tensor()
を使う場合もdtype
引数でデータ型を指定することができます。
y = torch.tensor([1, 2, 3], dtype=torch.float32) print(y) print(y.dtype) # 実行結果 # tensor([1., 2., 3.]) # torch.float32
ということで、基本はtorch.tensor()
を使う方が融通が効きます。
torch.Tensorのドキュメントにも、値を渡してtensorを作るときはtorch.tensor()
が推奨である旨記載があります。
To create a tensor with pre-existing data, use torch.tensor().
補足: 空のtensorを作るには
torch.tensor()
で空のtensorを作ろうとすると、一見してエラーが発生します。
empty_err = torch.tensor() print(empty_err) print(empty_err.dtype) # 実行結果: Error Traceback (most recent call last): File "/workspaces/python-examples/torch_tensor/main.py", line 25, in <module> empty_err = torch.tensor() TypeError: tensor() missing 1 required positional arguments: "data" # torch.Tensor()ではエラーは発生しない empty = torch.Tensor() print(empty) print(empty.dtype) # 実行結果 tensor([]) torch.float32
では空のtensorを作るときはtorch.Tensor()
を使った方が良いのかというと、そうではありません。
torch.tensor(())
とすることで、空のtensorを作成できます。
empty = torch.tensor(()) print(empty) print(empty.dtype) # 実行結果 tensor([]) torch.float32
おわりに
以上、torch.tensorとtorch.Tensorの違いをメモしました。
どなたかの参考になれば幸いです。
[関連記事]