自己寫了個(gè)dataloader,為了部署方便,用OpenCV的接口進(jìn)行數(shù)據(jù)讀取,而沒有用PIL,代碼大致如下:
def __getitem__(self, idx): sample = self.samples[idx] img = cv2.imread(sample[0]) img = cv2.resize(img, tuple(self.input_size)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # if not self.val and random.randint(1, 10) 3: # img = self.img_aug(img) img = Image.fromarray(img) img = self.transforms(img) ...
結(jié)果在訓(xùn)練過程中,在第1個(gè)epoch的最后一個(gè)batch時(shí),程序卡死。
可能是因?yàn)镺penCV與Pytorch互鎖的問題,關(guān)閉OpenCV的多線程,問題解決。
cv2.setNumThreads(0) cv2.ocl.setUseOpenCL(False)
補(bǔ)充:pytorch 中一個(gè)batch的訓(xùn)練過程
# 一般情況下 optimizer.zero_grad() # 梯度清零 preds = model(inputs) # inference,前向傳播求出預(yù)測(cè)值 loss = criterion(preds, targets) # 計(jì)算loss loss.backward() # 反向傳播求解梯度 optimizer.step() # 更新權(quán)重,更新網(wǎng)絡(luò)權(quán)重參數(shù)
此外,反向傳播前,如果不進(jìn)行梯度清零,則可以實(shí)現(xiàn)梯度累加,從而一定程度上解決顯存受限的問題。
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
標(biāo)簽:江蘇 蘭州 駐馬店 六盤水 常州 成都 宿遷 山東
巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《Pytorch dataloader在加載最后一個(gè)batch時(shí)卡死的解決》,本文關(guān)鍵詞 Pytorch,dataloader,在,加載,;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無關(guān)。