PyTorchでout of memoryになる

推論時に徐々にメモリ使用量が増えてout of memoryになる

推論時にout of memory になった。

out of memoryになったコードのイメージ。

for batch in dataloader:
    inputs = {"input_ids": batch[0]}
    outputs = model(**inputs)

出力されたエラー。
GPUのメモリ使用量をモニタリングすると、ループする度に使用するメモリ量が増えていき、out of memoryになる。

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 15.78 GiB total capacity; 14.43 GiB already allocated; 19.12 MiB free; 268.73 MiB cached)

解決策

with torch.no_grad(): を付ける。
(detachingでもよいが、lossが不要ならこっちのほうが簡単)

修正後のコードのイメージ。

with torch.no_grad():
    for batch in dataloader:
        inputs = {"input_ids": batch[0]}
        outputs = self.model(**inputs)

以下の投稿を見ると、lossのdetachもしくはwith torch.no_grad():でラップのどちらも行っていない場合、loss Tensorがcomputation graphを保持し続けるためにiteration毎にメモリ使用量が増加するとのこと。

Increasing memory usage for forward pass - PyTorch Forums