文章目錄[隱藏]
百度搜索技術創新挑戰賽(簡稱STI)是由百度搜索發起,聯合四大區域高校、學會(hui) 共同舉(ju) 辦的一項全國性科技競賽,共有兩(liang) 個(ge) 賽道。本賽道希望從(cong) 答案抽取和答案檢驗兩(liang) 個(ge) 方麵調研真實網絡環境下的文檔級機器閱讀理解技術,以求進一步提升深度智能問答效果,給用戶提供更好的搜索體(ti) 驗。
賽題背景
近年來,隨著機器閱讀理解與(yu) 深度預訓練模型等相關(guan) 技術的發展,抽取式智能問答係統的性能取得了非常明顯的提升。然而,在開放領域的搜索場景下得到的網頁數據會(hui) 非常複雜,其中往往存在著網頁文檔質量參差不齊、長短不一,問題答案分布零散、長度較長等問題,給答案抽取和答案置信度計算帶來了較大挑戰。
本賽題希望從(cong) 答案抽取和答案檢驗兩(liang) 個(ge) 方麵調研真實網絡環境下的文檔級機器閱讀理解技術,以求進一步提升深度智能問答效果,給用戶提供更好的搜索體(ti) 驗。
任務概述
本次任務共分為(wei) 兩(liang) 個(ge) 子任務,分別涉及基於(yu) 複雜網頁文檔內(nei) 容的答案抽取和答案檢驗技術,需全部完成。請用飛槳 AI Studio配置的NVIDIA A100完成參賽作品。
排名計算:選手根據提交要求將結果提交至AI Studio後,區域賽將基於(yu) 兩(liang) 個(ge) 任務的打榜結果加權平均選出前N名,無需評審。決(jue) 賽將基於(yu) 軟件延展開發、技術深度、創新性打分和打榜結果最終確定獲獎隊伍,決(jue) 賽將有專(zhuan) 家評審。
任務概述
當前基於(yu) 深度預訓練模型的機器閱讀理解方案在段落級智能問答任務上取得了非常好的性能,但在真實數據環境下的文檔級閱讀理解任務上的表現仍然難以令人滿意。如何在文檔長度不定,答案長度較長的數據環境中取得良好且魯棒的答案抽取效果是子任務1關(guan) 注的重點。 任務定義(yi)
給定一個(ge) 用戶搜索問題集合Q,基於(yu) 每個(ge) 搜索問題q,給定搜索引擎檢索得到的網頁文檔集合Dq,其中包括最多40個(ge) 網頁文檔。針對每個(ge) q-d對,要求參評係統從(cong) d中抽取能夠回答q的答案片段a。同一文檔中的答案可能為(wei) 不連續的多個(ge) 片段,文檔中也可能不包含答案。
數據集
訓練集包含約900個(ge) query、30000個(ge) query-document對;驗證集和測試集各包含約100個(ge) query,3000個(ge) query-document對。數據的主要特點為(wei) :
- 文檔長度普遍較長,質量參差不齊,內(nei) 部往往包含大量噪聲 - 句子級別答案片段,通常由包含完整上下文的若幹句子組成 - 標注數據隻保證答案片段與(yu) 搜索問題間的相關(guan) 性,不保證正確性,且存在不包含答案的文檔
數據樣例
問題q:備孕偶爾喝冰的可以嗎 篇章d:備孕能吃冷的食物嗎 炎熱的夏天讓很多人都覺得悶熱...,下麵一起來看看吧! 備孕能吃冷的食物嗎 在中醫養(yang) 生中,女性體(ti) 質屬陰,不可以貪涼。吃了過多寒涼、生冷的食物後,會(hui) 消耗陽氣,導致寒邪內(nei) 生,侵害子宮。另外,宮寒是腎陽虛的表現,不會(hui) 直接導致不孕。但宮寒會(hui) 引起婦科疾病,所以也不可不防。因此處於(yu) 備孕期的女性最好不要吃冷的食物。 備孕食譜有哪些 ... 答案a:在中醫養(yang) 生中,女性體(ti) 質屬陰,不可以貪涼。吃了過多寒涼、生冷的食物後,會(hui) 消耗陽氣,導致寒邪內(nei) 生,侵害子宮。另外,宮寒是腎陽虛的表現,不會(hui) 直接導致不孕。但宮寒會(hui) 引起婦科疾病,所以也不可不防。因此處於(yu) 備孕期的女性最好不要吃冷的食物。
評價(jia) 指標
計算基於(yu) 每個(ge) query-document對模型預測答案與(yu) 標注答案間字粒度的準確、召回、F1值,取所有測試數據的平均F1作為(wei) 最終指標。對於(yu) 不包含答案的文檔,其答案可看做一個(ge) 特殊token【無答案】,若模型預測出答案,其F1為(wei) 0,若模型預測【無答案】,其F1為(wei) 1。
方案介紹
賽題可以視為(wei) 基礎的信息抽取任務,也可以直接視為(wei) 問答類型的信息抽取問題。我們(men) 需要構建一個(ge) 模型,根據query從(cong) document中找到想要的答案。、
如果我們(men) 使用BERT 或者 ERNIE 可以直接參考如下思路,對於(yu) 模型的輸出可以直接輸出對應的兩(liang) 個(ge) 位置,對應為(wei) 回答的開始位置和結束位置。
這裏需要深入一下模型的實現細節:
- query和documnet是一起輸入給模型,一般情況下query在前麵。
- 回答對應的輸出可以通過模型輸出後的全連接層完成分類,當然回歸也可以。
如果采用QA的思路,則需要將比賽數據集轉換為(wei) QA的格式,特別是文本的處理:長文本需要進行截斷。
方案代碼
步驟1:解壓數據集
!pip install paddle-ernie > log.log # !cp data/data174963/data_task1.tar /home/aistudio/ !tar -xf /home/aistudio/data_task1.tar
步驟2:讀取數據集
# 導入常見的庫 import numpy as np import pandas as pd import os, sys, json
# 讀取訓練集、測試集和驗證集 train_json = pd.read_json('data_task1/train_data/train.json', lines=True) test_json = pd.read_json('data_task1/test_data/test.json', lines=True) dev_json = pd.read_json('data_task1/dev_data/dev.json', lines=True)
# 查看數據集樣例 train_json.head(1) 0 answer_list:[渝北區,重慶市轄區,屬重慶主城區、重慶大都市區,地處重慶市西北部。東(dong) 鄰長壽區、南與(yu) 江北區毗... title:渝北區_百度百科 url:https://baike.baidu***.com/item/渝北區/2531151 answer_start_list:[4] doc_text:渝北區 渝北區,重慶市轄區,屬重慶主城區、重慶大都市區,地處重慶市西北部。東(dong) 鄰長壽區、南與(yu) 江... query:渝北區麵積 org_answer:渝北區,重慶市轄區,屬重慶主城區、重慶大都市區,地處重慶市西北部。東(dong) 鄰長壽區、南與(yu) 江北區毗鄰...
train_json.iloc[0]
answer_list [渝北區,重慶市轄區,屬重慶主城區、重慶大都市區,地處重慶市西北部。東(dong) 鄰長壽區、南與(yu) 江北區毗... title 渝北區_百度百科 url https://baike.baidu***.com/item/渝北區/2531151 answer_start_list [4] doc_text 渝北區 渝北區,重慶市轄區,屬重慶主城區、重慶大都市區,地處重慶市西北部。東(dong) 鄰長壽區、南與(yu) 江... query 渝北區麵積 org_answer 渝北區,重慶市轄區,屬重慶主城區、重慶大都市區,地處重慶市西北部。東(dong) 鄰長壽區、南與(yu) 江北區毗鄰... Name: 0, dtype: object
這裏的數據集為(wei) 如下格式:
- query:用戶搜索query
- title:網頁標題
- url:網頁網頁
- doc_tet:網頁文檔正文
- answer_list:文檔中包含的答案,可能有多條不連續的答案,以列表存儲
- answer_start_list:答案開始的字符位置
- org_answer:合並文檔中所有答案,用於評測,無答案的為NoAnswer
訓練集和驗證集將包含上麵所有字段,而測試集將隻包含title、url、doc_text和query
test_json.iloc[0]
title Win11藍牙鼠標連接不上電腦怎麽(me) 辦-路由器之家 url https://www.wanqh.com/176866.html doc_text Win11藍牙鼠標連接不上電腦怎麽(me) 辦 路由器網投稿:文章是關(guan) 於(yu) "Win11藍牙鼠標連接不上... query 連接不到外設是win11問題嗎 Name: 0, dtype: object
步驟3:加載ERNIE模型
這裏我們(men) 使用paddlenlp==2.0.7,當然你也可以選擇更高的版本。更高的版本會(hui) 將損失計算也封裝進去,其他的部分區別不大。
import paddle import paddlenlp print('paddle version', paddle.__version__) print('paddlenlp version', paddlenlp.__version__)
from paddlenlp.transformers import ErnieForQuestionAnswering, ErnieTokenizer
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
model = ErnieForQuestionAnswering.from_pretrained('ernie-1.0')
paddle version 2.1.2 paddlenlp version 2.0.0
[2022-11-13 16:22:45,093] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-1.0/vocab.txt
[2022-11-13 16:22:45,106] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-1.0/ernie_v1_chn_base.pdparams
# 對文檔的文檔進行劃分、計算文檔的長度 train_json['doc_sentence'] = train_json['doc_text'].str.split('。') train_json['doc_sentence_length'] = train_json['doc_sentence'].apply(lambda doc: [len(sentence) for sentence in doc]) train_json['doc_sentence_length_max'] = train_json['doc_sentence_length'].apply(max) train_json = train_json[train_json['doc_sentence_length_max'] < 10000] # 刪除了部分超長文檔
# 對文檔的文檔進行劃分、計算文檔的長度 dev_json['doc_sentence'] = dev_json['doc_text'].str.split('。')
dev_json['doc_sentence_length'] = dev_json['doc_sentence'].apply(lambda doc: [len(sentence) for sentence in doc])
dev_json['doc_sentence_length_max'] = dev_json['doc_sentence_length'].apply(max)
dev_json = dev_json[dev_json['doc_sentence_length_max'] < 10000] # 刪除了部分超長文檔
# 對文檔的文檔進行劃分、計算文檔的長度 test_json['doc_sentence'] = test_json['doc_text'].str.split('。')
test_json['doc_sentence_length'] = test_json['doc_sentence'].apply(lambda doc: [len(sentence) for sentence in doc])
test_json['doc_sentence_length_max'] = test_json['doc_sentence_length'].apply(max)
train_json.iloc[10]
answer_list [渝北區位於(yu) 重慶市北部,長江北岸,嘉陵江下遊東(dong) 岸的三角地帶,幅員麵積1452平方公裏,有耕地... title 重慶市渝北區地名介紹 url https://www.tcmap.com.cn/chongqing/yubeiqu.html answer_start_list [788] doc_text 重慶市渝北區 [移動版] 2022年4月,渝北區被確定為(wei) “十四五”時期“無廢城市”建設名... query 渝北區麵積 org_answer 渝北區位於(yu) 重慶市北部,長江北岸,嘉陵江下遊東(dong) 岸的三角地帶,幅員麵積1452平方公裏,有耕地6... doc_sentence [重慶市渝北區 [移動版] 2022年4月,渝北區被確定為(wei) “十四五”時期“無廢城市”建設... doc_sentence_length [47, 44, 78, 41, 46, 39, 46, 33, 31, 38, 46, 3... doc_sentence_length_max 418 Name: 10, dtype: object
test_json.iloc[10]
title Win11連接無線鼠標沒反應什麽(me) 原因_Win11連接無線鼠標沒反應的解決(jue) 技巧_U教授 url https://www.ujiaoshou.com/xtjc/154037266.html doc_text Win11連接無線鼠標沒反應什麽(me) 原因 Win11連接無線鼠標沒反應的解決(jue) 技巧 電腦升級成w... query 連接不到外設是win11問題嗎 doc_sentence [Win11連接無線鼠標沒反應什麽(me) 原因 Win11連接無線鼠標沒反應的解決(jue) 技巧 電腦升級成... doc_sentence_length [87, 53, 67, 39, 42, 11, 38, 31, 19, 23, 38, 3... doc_sentence_length_max 87 Name: 10, dtype: object
步驟4:構建數據集
接下來需要構建QA任務的數據集,這裏的數據集需要處理為(wei) 如下的格式:
query [SEP] sentence of document
- 訓練集數據集處理
train_encoding = []
# for idx in range(len(train_json)): for idx in range(10000):
# 讀取原始數據的一條樣本 title = train_json.iloc[idx]['title']
answer_start_list = train_json.iloc[idx]['answer_start_list']
answer_list = train_json.iloc[idx]['answer_list']
doc_text = train_json.iloc[idx]['doc_text']
query = train_json.iloc[idx]['query']
doc_sentence = train_json.iloc[idx]['doc_sentence'] # 對於(yu) 文章中的每個(ge) 句子 for sentence in set(doc_sentence):
# 如果存在答案 for answer in answer_list:
answer = answer.strip("。") # 如果問題 + 答案 太長,跳過 if len(query + sentence) > 512: continue # 對問題 + 答案進行編碼 encoding = tokenizer.encode(query, sentence, max_seq_len=512, return_length=True,
return_position_ids=True, pad_to_max_seq_len=True, return_attention_mask=True) # 如果答案在這個(ge) 句子中,找到start 和 end的 位置 if answer in sentence:
encoding['start_positions'] = len(query) + 2 + sentence.index(answer)
encoding['end_positions'] = len(query) + 2 + sentence.index(answer) + len(answer)
# 如果不存在,則位置設置為(wei) 0 else:
encoding['start_positions'] = 0 encoding['end_positions'] = 0 # 存儲(chu) 正樣本 if encoding['start_positions'] != 0:
train_encoding.append(encoding) # 對負樣本進行采樣,因為(wei) 負樣本太多 # 正樣本:query + sentence -> answer 的情況 # 負樣本:query + sentence -> No answer 的情況 if encoding['start_positions'] == 0 and np.random.randint(0, 100) > 99:
train_encoding.append(encoding) if len(train_encoding) % 500 == 0:
print(len(train_encoding))
- 驗證集數據集處理
val_encoding = []
for idx in range(len(dev_json)): # for idx in range(200): title = dev_json.iloc[idx]['title']
answer_start_list = dev_json.iloc[idx]['answer_start_list']
answer_list = dev_json.iloc[idx]['answer_list']
doc_text = dev_json.iloc[idx]['doc_text']
query = dev_json.iloc[idx]['query']
doc_sentence = dev_json.iloc[idx]['doc_sentence']
for sentence in set(doc_sentence): for answer in answer_list:
answer = answer.strip("。")
if len(query + sentence) > 512: continue encoding = tokenizer.encode(query, sentence, max_seq_len=512, return_length=True,
return_position_ids=True, pad_to_max_seq_len=True, return_attention_mask=True)
if answer in sentence:
encoding['start_positions'] = len(query) + 2 + sentence.index(answer)
encoding['end_positions'] = len(query) + 2 + sentence.index(answer) + len(answer) else:
encoding['start_positions'] = 0 encoding['end_positions'] = 0 if encoding['start_positions'] != 0:
val_encoding.append(encoding)
if encoding['start_positions'] == 0 and np.random.randint(0, 100) > 99:
val_encoding.append(encoding)
- 測試集數據集處理
test_encoding = [] test_raw_txt = [] for idx in range(len(test_json)): title = test_json.iloc[idx]['title'] doc_text = test_json.iloc[idx]['doc_text'] query = test_json.iloc[idx]['query'] doc_sentence = test_json.iloc[idx]['doc_sentence']
for sentence in set(doc_sentence): if len(query + sentence) > 512: continue encoding = tokenizer.encode(query, sentence, max_seq_len=512, return_length=True,
return_position_ids=True, pad_to_max_seq_len=True, return_attention_mask=True)
test_encoding.append(encoding)
test_raw_txt.append(
[idx, query, sentence]
)
步驟5:批量數據讀取
# 手動將數據集進行批量打包 def data_generator(data_encoding, batch_size = 6): for idx in range(len(data_encoding) // batch_size): batch_data = data_encoding[idx * batch_size : (idx+1) * batch_size] batch_encoding = {} for key in batch_data[0].keys(): if key == 'seq_len': continue batch_encoding[key] = paddle.to_tensor(np.array([x[key] for x in batch_data])) yield batch_encoding
步驟6:模型訓練與(yu) 驗證
# 優(you) 化器 optimizer = paddle.optimizer.SGD(0.0005, parameters=model.parameters())
# 損失函數 loss_fct = paddle.nn.CrossEntropyLoss()
best_val_start_acc = 0
for epoch in range(10): # 每次打亂(luan) 訓練集,防止過擬合 np.random.shuffle(train_encoding) # 訓練部分 train_loss = [] for batch_encoding in data_generator(train_encoding, 10):
# ERNIE正向傳(chuan) 播 start_logits, end_logits = model(batch_encoding['input_ids'], batch_encoding['token_type_ids'])
# 計算損失 start_loss = loss_fct(start_logits, batch_encoding['start_positions'])
end_loss = loss_fct(end_logits, batch_encoding['end_positions'])
total_loss = (start_loss + end_loss) / 2 # 參數更新 total_loss.backward()
train_loss.append(total_loss)
optimizer.step()
optimizer.clear_gradients() # 驗證部分 val_start_acc = []
val_end_acc = [] with paddle.no_grad(): for batch_encoding in data_generator(val_encoding, 10):
# ERNIE正向傳(chuan) 播 start_logits, end_logits = model(batch_encoding['input_ids'], batch_encoding['token_type_ids'])
# 計算識別精度 start_acc = paddle.mean((start_logits.argmax(1) == batch_encoding['start_positions']).astype(float))
end_acc = paddle.mean((end_logits.argmax(1) == batch_encoding['end_positions']).astype(float))
val_start_acc.append(start_acc)
val_end_acc.append(end_acc)
# 轉換數據格式為(wei) float train_loss = paddle.to_tensor(train_loss).mean().item()
val_start_acc = paddle.to_tensor(val_start_acc).mean().item()
val_end_acc = paddle.to_tensor(val_end_acc).mean().item() # 存儲(chu) 最優(you) 模型 if val_start_acc > best_val_start_acc:
paddle.save(model.state_dict(), 'model.pkl')
best_val_start_acc = val_start_acc # 每個(ge) epoch打印輸出結果 print(f'Epoch {epoch}, {train_loss:3f}, {val_start_acc:3f}/{val_end_acc:3f}')
# 關(guan) 閉dropout model = model.eval()
步驟7:模型預測
test_start_idx = [] test_end_idx = []
# 對測試集中query 和 sentence的情況進行預測 with paddle.no_grad(): for batch_encoding in data_generator(test_encoding, 12):
start_logits, end_logits = model(batch_encoding['input_ids'], batch_encoding['token_type_ids'])
test_start_idx += start_logits.argmax(1).tolist()
test_end_idx += end_logits.argmax(1).tolist() if len(test_start_idx) % 500 == 0:
print(len(test_start_idx), len(test_encoding))
test_submit = [''] * len(test_json)
# 對預測結果進行後處理 for (idx, query, sentence), st_idx, end_idx in zip(test_raw_txt, test_start_idx, test_end_idx):
# 如果start 或 end位置識別失敗,或 start位置 晚於(yu) end位置 if st_idx == 0 or end_idx == 0 or st_idx >= end_idx: continue # 如果start位置在query部分 if st_idx - len(query) - 2 < 0: continue test_submit[idx] += sentence[st_idx - len(query) - 2: end_idx - len(query) - 2]
# 生成提交結果 with open('subtask1_test_pred.txt', 'w') as up: for x in test_submit: if x == '': up.write('1tNoAnswern') else: up.write('1t'+x+'n')
方案總結
- 借助QA的思路,本代碼可以使用ERNIE快速完成模型訓練與(yu) 提交。
- 本思路可以在測試集上進行預測,預測步驟需要11分鍾。
- 模型優(you) 化器和學習(xi) 率已經調節多次,後續也可以自己調節。
改進方向
從(cong) 精度改變大小,可以從(cong) 以下幾個(ge) 角度改進:訓練數據 > 數據處理 > 模型與(yu) 預訓練 > 模型集成
- 訓練數據:使用全量的訓練數據
- 數據處理:對文檔進行切分,現在使用。進行切分,後續也可以嚐試其他。
- 模型與預處理:嚐試ERNIE版本,或者進行預訓練。
- 模型集成:
- 嚐試不同的數據劃分得到不同的模型
- 嚐試不同的文本處理方法得到不同的模型
當然也可以考慮其他數據,如不同的網頁擁有答案的概率不同,以及從(cong) 標題可以判斷是否包含答案。
Notebook & 一鍵運行提交: https://aistudio.baidu***.com/aistudio/projectdetail/5013840
評論已經被關(guan) 閉。