์ด์ ๋ถํฐ S2S ๋ชจ๋ธ์ ์ดํ
์
๋ฉ์ปค๋์ฆ์ ๊ฐ์ํํ์ฌ ๊ตฌํํด๋ณด๊ฒ ์ต๋๋ค. ๋ฐ์ดํฐ์
์ ์์ผ๋ก ์ด๋ฃจ์ด์ง ๋ง๋ญ์น๋ก ์์ด ๋ฌธ์ฅ๊ณผ ์ด์ ์์ํ๋ ํ๋์ค์ด ๋ฒ์ญ์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. ๋ชจ๋ธ์ ์ธ์ฝ๋๋ ์๋ฐฉํฅ GRU์ ๋์ ์ฌ์ฉํ๊ณ ์ํ์ค์ ์๋ ๋ชจ๋ ๋ถ๋ถ์ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์
๋ ฅ ์ํ์ค์ ๊ฐ ์์น์ ๋ํ ๋ฒกํฐ๋ฅผ ๊ณ์ฐํฉ๋๋ค. ํ์ดํ ์น์ PackedSequence ๋ฐ์ดํฐ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ๋ฉฐ, ์ดํ์ ๋์ฌ โNMT ๋ชจ๋ธ์ ์ธ์ฝ๋ฉ๊ณผ ๋์ฝ๋ฉโ์์ ์์ธํ ๋ค๋ค๋ณด๊ฒ ์ต๋๋ค.
ย
๊ธฐ๊ณ ๋ฒ์ญ ๋ฐ์ดํฐ์
์ด๋ฒ ์์ ์์๋ ๋ฐ์ดํฐ์
์ผ๋ก ํํ ์๋ฐ ํ๋ก์ ํธ(Tatoeba Project)์ ์์ด-ํ๋์ค์ด ๋ฌธ์ฅ ์์ผ๋ก ๊ตฌ์ฑ๋ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ฐ์ ๋ชจ๋ ๋ฌธ์๋ฅผ ์๋ฌธ์๋ก ๋ณํํ๊ณ , NLTK๋ฅผ ์ด์ฉํ์ฌ ์์ด, ํ๋์ค์ด ํ ํฐํ๋ฅผ ๊ฐ ๋ฌธ์ฅ ์์ ์ ์ฉํฉ๋๋ค. ์ดํ NLTK์ ์ธ์ด์ ํนํ๋ ๋จ์ด ํ ํฐํ๋ฅผ ์ ์ฉํด ํ ํฐ๋ฆฌ์คํธ๋ฅผ ๋ง๋ญ๋๋ค. ๋ฐฉ๊ธ๊น์ง์ ๊ธฐ๋ณธ ์ ์ฒ๋ฆฌ์ ํน์ ๋ฌธ์ฅ ํจํด์ ์ง์ ํ์ฌ ๋ฐ์ดํฐ์ ์ผ๋ถ๋ถ๋ง ์ ํํด ํ์ต ๋ฌธ์ ๋ฅผ ๋จ์ํ๊ฒ ๋ง๋ค ์ ์์ต๋๋ค. ์ ํ๋ ๋ฌธ์ฅ ํจํด์ผ๋ก ๋ฐ์ดํฐ ๋ฒ์๋ฅผ ์ขํ๋ ๊ฒ์
๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ๋ชจ๋ธ์ ๋ถ์ฐ์ ๋ฎ์ถ๊ณ ์งง์ ์๊ฐ ๋ด์ ๋์ ์ฑ๋ฅ ๋ฌ์ฑ์ด ๊ฐ๋ฅํฉ๋๋ค.
ย
NMT๋ฅผ ์ํ ๋ฒกํฐ ํ์ดํ๋ผ์ธ
์์ค ์์ด์ ํ๊น ํ๋์ค์ด๋ฅผ ๋ฒกํฐ๋ก ๋ณํํ๊ธฐ ์ํด์๋ ๋ณต์กํ ํ์ดํ๋ผ์ธ์ด ํ์ํฉ๋๋ค. ๋ณต์ก๋๊ฐ ์ฆ๊ฐํ๋ ์์ธ์๋ ํฌ๊ฒ ๋ ๊ฐ์ง๊ฐ ์์ต๋๋ค. ์ฒซ ๋ฒ์งธ๋ก ์์ค์ ํ๊น ์ํ์ค๋ ๋ชจ๋ธ์์ ๋ค๋ฅธ ์ญํ ์ ํ๊ณ , ์ธ์ด๋ ๋ค๋ฅด๋ฉฐ, ๋ฒกํฐํ๋๋ ๋ฐฉ์๋ ๋ค๋ฅด๊ธฐ ๋๋ฌธ์
๋๋ค. ๋ ์งธ๋ก ํ์ดํ ์น์ PackedSequence๋ฅผ ์ฌ์ฉํ ๋์๋ ์์ค ์ํ์ค์ ๊ธธ์ด์ ๋ฐ๋ผ ๊ฐ ๋ฏธ๋๋ฐฐ์น๋ฅผ ์ํ
ํด์ผํ๊ธฐ ๋๋ฌธ์
๋๋ค. ์ ๋ ๊ฐ์ง ๋ฌธ์ ๋๋ฌธ์ NMTVectorizer๋ ๋ณ๋์ SequenceVocabulary ๊ฐ์ฒด ๋ ๊ฐ๋ฅผ ๋ง๋ค๊ณ , ์ต๋ ์ํ์ค ๊ธธ์ด๋ฅผ ๋ฐ๋ก ์ธก์ ํฉ๋๋ค.
class NMTVectorizer(object): """ ์ดํ ์ฌ์ ์ ์์ฑํ๊ณ ๊ด๋ฆฌํฉ๋๋ค """ def __init__(self, source_vocab, target_vocab, max_source_length, max_target_length): """ ๋งค๊ฐ๋ณ์: source_vocab (SequenceVocabulary): ์์ค ๋จ์ด๋ฅผ ์ ์์ ๋งคํํฉ๋๋ค target_vocab (SequenceVocabulary): ํ๊น ๋จ์ด๋ฅผ ์ ์์ ๋งคํํฉ๋๋ค max_source_length (int): ์์ค ๋ฐ์ดํฐ์ ์์ ๊ฐ์ฅ ๊ธด ์ํ์ค ๊ธธ์ด max_target_length (int): ํ๊น ๋ฐ์ดํฐ์ ์์ ๊ฐ์ฅ ๊ธด ์ํ์ค ๊ธธ์ด """ self.source_vocab = source_vocab self.target_vocab = target_vocab self.max_source_length = max_source_length self.max_target_length = max_target_length @classmethod def from_dataframe(cls, bitext_df): """ ๋ฐ์ดํฐ์ ๋ฐ์ดํฐํ๋ ์์ผ๋ก NMTVectorizer๋ฅผ ์ด๊ธฐํํฉ๋๋ค ๋งค๊ฐ๋ณ์: bitext_df (pandas.DataFrame): ํ ์คํธ ๋ฐ์ดํฐ์ ๋ฐํ๊ฐ : NMTVectorizer ๊ฐ์ฒด """ source_vocab = SequenceVocabulary() target_vocab = SequenceVocabulary() max_source_length = 0 max_target_length = 0 for _, row in bitext_df.iterrows(): source_tokens = row["source_language"].split(" ") if len(source_tokens) > max_source_length: max_source_length = len(source_tokens) for token in source_tokens: source_vocab.add_token(token) target_tokens = row["target_language"].split(" ") if len(target_tokens) > max_target_length: max_target_length = len(target_tokens) for token in target_tokens: target_vocab.add_token(token) return cls(source_vocab, target_vocab, max_source_length, max_target_length)
ย
๋ณต์ก๋๊ฐ ์ฆ๊ฐํ๋ ์ฒซ ๋ฒ์งธ ์์ธ์ผ๋ก ์์ค์ ํ๊น ์ํ์ค๋ฅผ ๋ค๋ฃจ๋ ๋ฐฉ๋ฒ์ด ๋ค๋ฅด๋ค๊ณ ์ค๋ช
ํ์ต๋๋ค. ์์ค ์ํ์ค๋ ์์ ๋ถ๋ถ์ BEGIN-OF-SEQUENCE ํ ํฐ์, ๋ง์ง๋ง์ END-OF-SEQUENCE ํ ํฐ์ ์ถ๊ฐํ๋ฉฐ ๋ฒกํฐํํฉ๋๋ค. ์ด ๋ชจ๋ธ์ ์๋ฐฉํฅ GRU๋ฅผ ์ฌ์ฉํ์ฌ ์์ค ์ํ์ค์ ์๋ ํ ํฐ์ ์ํ ์์ฝ ๋ฒกํฐ๋ฅผ ๋ง๋ญ๋๋ค. ๋ฐ๋ฉด์ ํ๊น ์ํ์ค๋ ํ ํฐ ํ๋๊ฐ ๋ฐ๋ฆฐ ๋ณต์ฌ๋ณธ ๋ ๊ฐ๋ก ๋ฒกํฐํ๋ฉ๋๋ค. ์ํ์ค ์์ธก ์์
์๋ ํ์ ์คํ
๋ง๋ค ์
๋ ฅ ํ ํฐ๊ณผ ์ถ๋ ฅ ํ ํฐ์ด ํ์ํ๋ฐ, S2S์ ๋ชจ๋ธ ๋์ฝ๋๊ฐ ์ด ์์
์ ์ํํ๋ฉด์๋ ์ธ์ฝ๋ ๋ฌธ๋งฅ์ด ์ถ๊ฐ๋ฉ๋๋ค. ์ด ์์
์ ๋จ์ํํ๊ธฐ ์ํด ์์ค์ ํ๊น ์ธ๋ฑ์ค์ ์๊ด์์ด ๋ฒกํฐํ๋ฅผ ์ํํ๋ _vectorize() ๋ฉ์๋๋ฅผ ๋ง๋ค์์ต๋๋ค. ๋ค์์ผ๋ก ์ธ๋ฑ์ค๋ฅผ ๊ฐ๊ธฐ ์ฒ๋ฆฌํ๋ ๋ ๋ฉ์๋๋ฅผ ๋ง๋ญ๋๋ค.
def _vectorize(self, indices, vector_length=-1, mask_index=0): """์ธ๋ฑ์ค๋ฅผ ๋ฒกํฐ๋ก ๋ณํํฉ๋๋ค ๋งค๊ฐ๋ณ์: indices (list): ์ํ์ค๋ฅผ ๋ํ๋ด๋ ์ ์ ๋ฆฌ์คํธ vector_length (int): ์ธ๋ฑ์ค ๋ฒกํฐ์ ๊ธธ์ด mask_index (int): ์ฌ์ฉํ ๋ง์คํฌ ์ธ๋ฑ์ค; ๊ฑฐ์ ํญ์ 0 """ if vector_length < 0: vector_length = len(indices) vector = np.zeros(vector_length, dtype=np.int64) vector[:len(indices)] = indices vector[len(indices):] = mask_index return vector def _get_source_indices(self, text): """ ๋ฒกํฐ๋ก ๋ณํ๋ ์์ค ํ ์คํธ๋ฅผ ๋ฐํํฉ๋๋ค ๋งค๊ฐ๋ณ์: text (str): ์์ค ํ ์คํธ; ํ ํฐ์ ๊ณต๋ฐฑ์ผ๋ก ๊ตฌ๋ถ๋์ด์ผ ํฉ๋๋ค ๋ฐํ๊ฐ: indices (list): ํ ์คํธ๋ฅผ ํํํ๋ ์ ์ ๋ฆฌ์คํธ """ indices = [self.source_vocab.begin_seq_index] indices.extend(self.source_vocab.lookup_token(token) for token in text.split(" ")) indices.append(self.source_vocab.end_seq_index) return indices def _get_target_indices(self, text): """ ๋ฒกํฐ๋ก ๋ณํ๋ ํ๊น ํ ์คํธ๋ฅผ ๋ฐํํฉ๋๋ค ๋งค๊ฐ๋ณ์: text (str): ํ๊น ํ ์คํธ; ํ ํฐ์ ๊ณต๋ฐฑ์ผ๋ก ๊ตฌ๋ถ๋์ด์ผ ํฉ๋๋ค ๋ฐํ๊ฐ: ํํ: (x_indices, y_indices) x_indices (list): ๋์ฝ๋์์ ์ํ์ ๋ํ๋ด๋ ์ ์ ๋ฆฌ์คํธ y_indices (list): ๋์ฝ๋์์ ์์ธก์ ๋ํ๋ด๋ ์ ์ ๋ฆฌ์คํธ """ indices = [self.target_vocab.lookup_token(token) for token in text.split(" ")] x_indices = [self.target_vocab.begin_seq_index] + indices y_indices = indices + [self.target_vocab.end_seq_index] return x_indices, y_indices def vectorize(self, source_text, target_text, use_dataset_max_lengths=True): """ ๋ฒกํฐํ๋ ์์ค ํ ์คํธ์ ํ๊น ํ ์คํธ๋ฅผ ๋ฐํํฉ๋๋ค ๋ฒกํฐํ๋ ์์ค ํ ์ฝํธ๋ ํ๋์ ๋ฒกํฐ์ ๋๋ค. ๋ฒกํฐํ๋ ํ๊น ํ ์คํธ๋ 7์ฅ์ ์ฑ์จ ๋ชจ๋ธ๋ง๊ณผ ๋น์ทํ ์คํ์ผ๋ก ๋ ๊ฐ์ ๋ฒกํฐ๋ก ๋๋ฉ๋๋ค. ๊ฐ ํ์ ์คํ ์์ ์ฒซ ๋ฒ์งธ ๋ฒกํฐ๊ฐ ์ํ์ด๊ณ ๋ ๋ฒ์งธ ๋ฒกํฐ๊ฐ ํ๊น์ด ๋ฉ๋๋ค. ๋งค๊ฐ๋ณ์: source_text (str): ์์ค ์ธ์ด์ ํ ์คํธ target_text (str): ํ๊น ์ธ์ด์ ํ ์คํธ use_dataset_max_lengths (bool): ์ต๋ ๋ฒกํฐ ๊ธธ์ด๋ฅผ ์ฌ์ฉํ ์ง ์ฌ๋ถ ๋ฐํ๊ฐ: ๋ค์๊ณผ ๊ฐ์ ํค์ ๋ฒกํฐํ๋ ๋ฐ์ดํฐ๋ฅผ ๋ด์ ๋์ ๋๋ฆฌ: source_vector, target_x_vector, target_y_vector, source_length """ source_vector_length = -1 target_vector_length = -1 if use_dataset_max_lengths: source_vector_length = self.max_source_length + 2 target_vector_length = self.max_target_length + 1 source_indices = self._get_source_indices(source_text) source_vector = self._vectorize(source_indices, vector_length=source_vector_length, mask_index=self.source_vocab.mask_index) target_x_indices, target_y_indices = self._get_target_indices(target_text) target_x_vector = self._vectorize(target_x_indices, vector_length=target_vector_length, mask_index=self.target_vocab.mask_index) target_y_vector = self._vectorize(target_y_indices, vector_length=target_vector_length, mask_index=self.target_vocab.mask_index) return {"source_vector": source_vector, "target_x_vector": target_x_vector, "target_y_vector": target_y_vector, "source_length": len(source_indices)}
ย
๋ณต์ก๋์ ๋ค์ ์์ธ์ ์์ค ์ํ์ค์
๋๋ค. ์๋ฐฉํฅ GRU๋ก ์์ค ์ํ์ค๋ฅผ ์ธ์ฝ๋ฉํ ๋ ํ์ดํ ์น์ PackedSequence ๋ฐ์ดํฐ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค. ๊ฐ๋ณ ๊ธธ์ด ์ํ์ค์ ๋ฏธ๋๋ฐฐ์น๋ ๊ฐ ์ํ์ค๋ฅผ ํ์ผ๋ก ์์ ์ ์ ํ๋ ฌ๋ก ํํ๋๋ฉฐ, ์ํ์ค๋ ์ผ์ชฝ ์ ๋ ฌ๋๊ณ ์ ๋ก ํจ๋ฉ๋์ด ๊ฐ๋ณ ๊ธธ์ด๋ฅผ ํ์ฉํ๊ฒ ๋ฉ๋๋ค. PackedSequence ๋ฐ์ดํฐ ๊ตฌ์กฐ๋ ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ๊ฐ๋ณ ๊ธธ์ด ์ํ์ค ๋ฏธ๋๋ฐฐ์น๋ฅผ ๋ฐฐ์ด ํ๋๋ก ํํํฉ๋๋ค. ์ํ์ค์ ํ์ ์คํ
๋ฐ์ดํฐ๋ฅผ ์ฐจ๋ก๋๋ก ์ฐ๊ฒฐํ๊ณ ํ์ ์คํ
๋ง๋ค ์ํ์ค ๊ธธ์ด๋ฅผ ๊ธฐ๋กํ๊ฒ ๋ฉ๋๋ค.
ย
PackedSequence๋ฅผ ๋ง๋ค๋ ค๋ฉด ๊ฐ ์ํ์ค์ ๊ธธ์ด๋ฅผ ์์์ผ ํ๋ฉฐ, ์ํ์ค์ ๊ธธ์ด ์์๋๋ก ๋ด๋ฆผ์ฐจ์ ์ ๋ ฌ์ ํด์ผ ํฉ๋๋ค. ์ ๋ ฌ๋ ํ๋ ฌ์ ๋ง๋ค๊ธฐ ์ํด์ ๋ฏธ๋๋ฐฐ์น์ ์๋ ํ
์๋ฅผ ์ํ์ค ๊ธธ์ด ์์๋๋ก ์ ๋ ฌํฉ๋๋ค. ์๋ ์ฝ๋๋ generate_batches()๋ฅผ ์์ ํ generate_nmt_batches() ํจ์์
๋๋ค.
def generate_nmt_batches(dataset, batch_size, shuffle=True, drop_last=True, device="cpu"): """ ํ์ดํ ์น DataLoader๋ฅผ ๊ฐ์ธ๊ณ ์๋ ์ ๋๋ ์ดํฐ ํจ์; NMT ๋ฒ์ """ dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) for data_dict in dataloader: lengths = data_dict['x_source_length'].numpy() sorted_length_indices = lengths.argsort()[::-1].tolist() out_data_dict = {} for name, tensor in data_dict.items(): out_data_dict[name] = data_dict[name][sorted_length_indices].to(device) yield out_data_dict
ย
ย
์ธ์ฝ๋ฉ๊ณผ ๋์ฝ๋ฉ
์์ด๋ฅผ ํ๋์ค์ด๋ก ๋ฒ์ญํ๋ ๊ธฐ๊ณ ๋ณ์ญ์ ์คํํ๊ธฐ ์ํด์๋ ์ธ์ฝ๋-๋์ฝ๋ ๋ชจ๋ธ์ ์ฌ์ฉํฉ๋๋ค. ์ธ์ฝ๋๊ฐ ์๋ฐฉํฅ GRU๋ฅผ ์ฌ์ฉํ์ฌ ์์ค ์ํ์ค(์์ด ๋ฌธ์ฅ)์ ๋ฒกํฐ ์ํ์ ์ํ์ค๋ก ๋งคํํ๋ฉด ๋์ฝ๋๊ฐ ์ธ์ฝ๋์์ ์ถ๋ ฅ๋ ์๋ ์ํ๋ฅผ ์ด๊ธฐ ์๋ ์ํ๋ก ๋ฐ์์ ์ดํ
์
๋ฉ์ปค๋์ฆ์ผ๋ก ์์ค ์ํ์ค๋ฅผ ํ ๋๋ก ์ถ๋ ฅ ์ํ์ค(ํ๋์ค์ด ๋ฒ์ญ๋ฌธ)๋ฅผ ์์ฑํ๋ ๊ณผ์ ์ ๊ฑฐ์นฉ๋๋ค.
๋ค์๊ณผ ๊ฐ์ด ์ ๊ฒฝ๋ง ๊ธฐ๊ณ ๋ฒ์ญ ๋ชจ๋ธ NMTModel์ ์ฝ๋ฉํด๋ด
๋๋ค. NMTModel์ ํ๋์ forward() ๋ฉ์๋(์ ๋ฐฉํฅ ๊ณ์ฐํ๋ ๋ฉ์๋)์ ์ธ์ฝ๋์ ๋์ฝ๋๋ฅผ ์บก์ํํ์ฌ ๊ด๋ฆฌํฉ๋๋ค.
ย
class NMTModel(nn.Module): """ ์ ๊ฒฝ๋ง ๊ธฐ๊ณ ๋ฒ์ญ ๋ชจ๋ธ """ def __init__(self, source_vocab_size, source_embedding_size, target_vocab_size, target_embedding_size, encoding_size, target_bos_index): """ ๋งค๊ฐ๋ณ์: source_vocab_size (int): ์์ค ์ธ์ด์ ๊ณ ์ ํ ๋จ์ด ๊ฐ์ source_embedding_size (int): ์์ค ์๋ฒ ๋ฉ ๋ฒกํฐ์ ํฌ๊ธฐ target_vocab_size (int): ํ๊น ์ธ์ด์ ๊ณ ์ ํ ๋จ์ด ๊ฐ์ target_embedding_size (int): ํ๊น ์๋ฒ ๋ฉ ๋ฒกํฐ์ ํฌ๊ธฐ encoding_size (int): ์ธ์ฝ๋ RNN์ ํฌ๊ธฐ target_bos_index (int): BEGIN-OF-SEQUENCE ํ ํฐ ์ธ๋ฑ์ค """ super(NMTModel, self).__init__() self.encoder = NMTEncoder(num_embeddings = source_vocab_size, embedding_size = source_embedding_size, rnn_hidden_size = encoding_size) decoding_size = encoding_size = 2 self.decoder = NMTDecoder(num_embeddings = target_vocab_size, embedding_size = tarrget_embedding_size, rnn_hidden_size = decoding_size, bos_index = target_bos_index) def forward(self, x_source, x_source_lengths, target_sequence): """ ๋ชจ๋ธ์ ์ ๋ฐฉํฅ ๊ณ์ฐ ๋งค๊ฐ๋ณ์: x_source (torch.Tensor): ์์ค ํ ์คํธ ๋ฐ์ดํฐ ์ผ์ x_source.shape๋ (batch, vectorizer.max_source_length)์ ๋๋ค. x_source_lengths torch.Tensor): x_source์ ์ํ์ค ๊ธธ์ด target_sequence (torch.Tensor): ํ๊น ํ ์คํธ ๋ฐ์ดํฐ ํ ์ ๋ฐํ๊ฐ: decoded_states (torch.Tensor): ๊ฐ ์ถ๋ ฅ ํ์ ์คํ ์ ์์ธก ๋ฒกํฐ """ encoder_state, final_hidden_states = self.encoder(x_source, x_source_lengths) decoded_states = self.decoder(encoder_state = encoder_state, initial_hidden_state = final_hidden_states, target_sequence = target_sequence) return decoded_states
ย
๋ค์์ผ๋ก ์๋ฐฉํฅ GRU๋ฅผ ์ฌ์ฉํ์ฌ ๋จ์ด๋ฅผ ์๋ฒ ๋ฉํ๊ณ ํน์ฑ์ ์ถ์ถํ๋, ์ฆ, ์์ค ์ํ์ค๋ฅผ ๋ฒกํฐ ์ํ๋ก ๋งคํํ๋ ์ธ์ฝ๋ NMTEncoder๋ฅผ ์ฝ๋ฉํด๋ด
๋๋ค. ์ฌ๊ธฐ์ ์ธ์ฝ๋์ ์ถ๋ ฅ์ ์๋ฐฉํฅ GRU์ ์ต์ข
์๋ ์ํ๊ฐ ๋๊ณ ์ด๋ฅผ ์ดํ ๋์ฝ๋๊ฐ ๋ฐ๊ฒ ๋ฉ๋๋ค.
์ฝ๋๋ฅผ ์์ธํ ์ดํด๋ณด๋ฉด ์ฐ์ ์๋ฒ ๋ฉ ์ธต์ ์ฌ์ฉํด ์
๋ ฅ ์ํ์ค๋ฅผ ์๋ฒ ๋ฉํฉ๋๋ค. ์ด๋, ๊ฐ๋ณ ๊ธธ์ด ์ํ์ค๋ padding_idx๋ผ๋ ๋งค๊ฐ๋ณ์๋ฅผ ํตํด ์ฒ๋ฆฌํฉ๋๋ค. padding_idx์ ๋์ผํ ๋ชจ๋ ์์น๋ 0๋ฒกํฐ๊ฐ ๋๋ฉฐ ์ต์ ํ ๊ณผ์ ์ ๊ฑฐ์น ๋ ์
๋ฐ์ดํธ๋์ง ์๋ ๋ง์คํน(masking)์ด ๋๊ธฐ ๋๋ฌธ์ ๊ฐ๋ณ ๊ธธ์ด ์ํ์ค ์ฒ๋ฆฌ๊ฐ ๊ฐ๋ฅํฉ๋๋ค.
๋ค๋ง ์๋ฐฉํฅ GRU๋ ํน๋ณํ ์๋ฐฉํฅ์ผ ๋์ ์ญ๋ฐฉํฅ์ผ ๋์ ๋ง์คํน๋ ์์น๊ฐ ๋ฌ๋ผ์ง ์ ์๊ธฐ ๋๋ฌธ์ ์ธ์ฝ๋-๋์ฝ๋ ๋ชจ๋ธ์์๋ ๋ง์คํน ์์น๋ฅผ ๋ค๋ฅธ ๋ฐฉ์์ผ๋ก, ํ์ดํ ์น์ PackedSequence ๋ฐ์ดํฐ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ์ฌ ์ฒ๋ฆฌํ๋ค๊ณ ํฉ๋๋ค.
ย
class NMTEncoder(nn.Module): def __init__(self, num_embeddings, embedding_size, rnn_hidden_size): """ ๋งค๊ฐ๋ณ์: num_embeddings (int): ์๋ฒ ๋ฉ ๊ฐ์๋ ์์ค ์ดํ ์ฌ์ ์ ํฌ๊ธฐ์ ๋๋ค embedding_size (int): ์๋ฒ ๋ฉ ๋ฒกํฐ์ ํฌ๊ธฐ rnn_hidden_size (int): RNN ์๋ ์ํ ๋ฒกํฐ์ ํฌ๊ธฐ """ super(NMTEncoder, self).__init__() self.source_embedding = nn.Embedding(num_embeddings, embedding_size, padding_idx = 0) self.birnn = nn.GRU(embedding_size, rnn_hidden_size, bidirectional = True, batch_first = True) def forward(self, x_source, x_lengths): x_embedded = self.source_embedding(x_source) # create PackedSequence ์์ฑ x_packed.data.shape = (number_items, # embedding_size) x_lengths = x_lengths.detatch().cpu().numpy() x_packed = pack_padded_sequence(x_embedded, x_lengths, batch_first = True) # x_birnn_h.shape = (num_rnn, batch_size, feature_size) x_birnn_out, x_birnn_h = self.birnn(x_packed) # (batch_size, num_rnn. feature_size)๋ก ๋ณํ x_birnn_h = x_brnn_h.permute(1, 0, 2) # ํน์ฑ ํผ์นจ. (batch_size, num_rnn * feature_size)๋ก ๋ฐ๊พธ๊ธฐ # (์ฐธ๊ณ : -1์ ๋จ์ ์ฐจ์์ ํด๋นํ๋ฉฐ, # ๋ ๊นจ์ RNN ์๋ ๋ฒกํฐ๋ฅผ 1๋ก ํผ์นฉ๋๋ค) x_birnn_h = x_birnn_h.contiguous().view(x_birnn_h.size(0), -1) x_unpacked, _ = pad_packed_sequence(x_birnn_out, batch_first = True) return x_unpacked, x_birnn_h
ย
์ด์ ์ธ์ฝ๋์ ์ถ๋ ฅ์ธ ์ต์ข
์๋ ์ํ๋ฅผ ๋์ฝ๋ NMTDecoder์ด ๋ฐ์ ํ์ ์คํ
์ ์ํํ๋ฉด์ ์ถ๋ ฅ ์ํ์ค๋ฅผ ์์ฑํฉ๋๋ค.
์ด ์์ ๋ ํ๊น ์ํ์ค๊ฐ ํ์ ์คํ
๋ง๋ค ์ํ๋ก ์ ๊ณต๋๋ค๋ ํน์ด์ ์ด ์์ต๋๋ค. GRUCell์ ์ฌ์ฉํด ์๋ ์ํ๋ฅผ ๊ณ์ฐํ๋ฉด ์ธ์ฝ๋์ ์ต์ข
์๋ ์ํ์ Linear์ธต์ ์ ์ฉํ์ฌ ์ด๊ธฐ ์๋ ์ํ๋ฅผ ๊ณ์ฐํ๋๋ฐ, ์ด๋ ๋์ฝ๋ GRU๋ ์๋ฒ ๋ฉ๋ ์
๋ ฅ ํ ํฐ๊ณผ ๋ง์ง๋ง ํ์ ์คํ
์ ๋ฌธ๋งฅ ๋ฒกํฐ๋ฅผ ์ฐ๊ฒฐํ ๋ฒกํฐ๋ฅผ ์
๋ ฅ ๋ฐ๋๋ค๊ณ ํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ฟผ๋ฆฌ ๋ฒกํฐ๋ฅผ ์ฌ์ฉํ์ฌ ๊ทธ ์๋ก์ด ์
๋ ฅ ๋ฒกํฐ๋ฅผ ํ์ฌ ํ์ ์คํ
์์ ์ดํ
์
๋ฉ์ปค๋์ฆ์ผ๋ก ์๋ก์ด ๋ฌธ๋งฅ ๋ฒกํฐ๋ฅผ ๋ง๋ ํ ์๋ ์ํ์ ์ฐ๊ฒฐํ์ฌ ๋์ฝ๋ฉ ์ ๋ณด๋ฅผ ํํํ๋ ๋ฒกํฐ๋ฅผ ๋ง๋ญ๋๋ค. ์ด ๋ฒกํฐ๋ฅผ ์ด์ฉํ์ฌ ๋ถ๋ฅ๊ธฐ(๊ฐ๋จํ Linear์ธต)๊ฐ ์์ธก ๋ฒกํฐ score_for_y_t_index๋ฅผ ์์ฑํ๊ณ ์ํํธ๋งฅ์ค ํจ์๋ฅผ ์ด์ฉํ์ฌ ์์ธก ๋ฒกํฐ๋ฅผ ์์ฑํฉ๋๋ค.
ย
class NMTDecoder(nn.Module): def __init__(self, num_embeddings, embedding_size, rnn_hidden_size, bos_index): """ ๋งค๊ฐ๋ณ์: num_embeddings (int): ์๋ฒ ๋ฉ ๊ฐ์๋ ํ๊น ์ดํ ์ฌ์ ์ ๊ณ ์ ํ ๋จ์ด์ ๊ฐ์์ ๋๋ค embeddin_size (int): ์๋ฒ ๋ฉ ๋ฒกํฐ ํฌ๊ธฐ rnn_hidden_size (int): RNN ์๋ ์ํ ํฌ๊ธฐ bos_index (int): begin-of-sequence ์ธ๋ฑ์ค """ super(NMTDecoder, self).__init__() self._rnn_hidden_size = rnn_hidden_size self.target_embedding = nn.Embedding(num_embeddings = num_embeddings, embedding_dim = embedding_size, padding_ide = 0) self.gru_cell = nn.GRUCell(embedding_size + rnn_hidden_size, rnn_hidden_size) self.hidden_map = nn.Linear(rnn_hidden_size, rnn_hidden_size) self.classifie = nn.Linear(rnn_hidden_size * 2, num_embeddings) self.bos_index = bos_index def _init_indices(self, batch_size): """ BEGIN-OF-SEQUENCE ์ธ๋ฑ์ค ๋ฒกํฐ๋ฅผ ๋ฐํํฉ๋๋ค """ return torch.ones(batch_size, dtype = torch.int64) * self.bos_index def _init_context_vectors(self, batch_size): """ ๋ฌธ๋งฅ ๋ฒกํฐ๋ฅผ ์ด๊ธฐํํ๊ธฐ ์ํ 0 ๋ฒกํฐ๋ฅผ ๋ฐํํฉ๋๋ค """ return torch.zeros(batch_size, self._rnn_hidden_size) def forward(self, encoder_state, initial_hidden_state, target_sequence): """ """ # ๊ฐ์ : ์ฒซ ๋ฒ์งธ ์ฐจ์์ ๋ฐฐ์น ์ฐจ์์ ๋๋ค # ์ฆ ์ ๋ ฅ์ (Batch, Seq) # ์ํ์ค์ ๋ํด ๋ฐ๋ณตํด์ผ ํ๋ฏ๋ก (Seq, Batch)๋ก ์ฐจ์์ ๋ฐ๊ฟ๋๋ค target_sequence = target_sequence.permute(1,0) # ์ฃผ์ด์ง ์ธ์ฝ๋์ ์๋ ์ํ๋ฅผ ์ด๊ธฐ ์๋ ์ํ๋ก ์ฌ์ฉํฉ๋๋ค h_t = self.hidden_map(initial_hidden_state) batch_size = encoder_state_size(0) # ๋ฌธ๋งฅ ๋ฒกํฐ๋ฅผ 0์ผ๋ก ์ด๊ธฐํํฉ๋๋ค context_vectors - self._init_context_vectors(batch_size) # ์ฒ ๋จ์ด y_t๋ฅผ BOS๋ก ์ด๊ธฐํํฉ๋๋ค y_t_index = self._init_indeices(batch_size) h_t = h_t.to(encoder_state.device) y_index = y_t_index.to(encoder_state.device) context_vectors = context_vectors.to(encoder_state.device) output_vectors = [] #๋ถ์์ ์ํด GPU์์ ์บ์ฑ๋ ๋ชจ๋ ํ ์๋ฅผ ๊ฐ์ ธ์ ์ ์ฅํฉ๋๋ค self._cached_p_attn = [] self._cached_ht = [] self._cached_decoder_state = encoder_state.cpu().detatch().numpy() output_sequence_size = target_sequence_size(0) for i in range(output_sequence_size): # 1๋จ๊ณ: ๋จ์ด๋ฅผ ์๋ฒ ๋ฉํ๊ณ ์ด์ ๋ฌธ๋งฅ๊ณผ ์ฐ๊ฒฐํฉ๋๋ค y_input_vector = self.target_embedding(target_sequence[i]) rnn_input = torch.cat([y_input_vector, context_vectors], dim = 1) # 2๋จ๊ณ: GRU๋ฅผ ์ ์ฉํ๊ณ ์๋ก์ด ์๋ ๋ฒกํฐ๋ฅผ ์ป์ต๋๋ค h_t = self.gru_cell(rnn_input, h_t) self._cached_ht.append(h_t.cpu().data.numpy()) # 3๋จ๊ณ: ํ์ฌ ์๋ ์ํ๋ฅผ ์ฌ์ฉํด ์ธ์ฝ๋์ ์ํ๋ฅผ ์ฃผ๋ชฉํฉ๋๋ค context_vectors, p_attn, _ = \ verbose_attention(encoder_state_vectors = encoder_state, query_vector = h_t) # ๋ถ๊ฐ ์์ : ์๊ฐํ๋ฅผ ์ํด ์ดํ ์ ํ๋ฅ ์ ์ ์ฅํฉ๋๋ค self._cached_p_attn.append(p_attn.cpu().detatch().numpy()) # 4๋จ๊ณ: ํ์ฌ ์๋ ์ํ์ ๋ฌธ๋งฅ ๋ฒกํฐ๋ฅผ ์ฌ์ฉํด ๋ค์ ๋จ์ด๋ฅผ ์์ธกํฉ๋๋ค prediction_vector = torch.cat((context_vectors, h_t), dim = 1) score_for_y_t_index = self.classifier(prediction_vector) # ๋ถ๊ฐ ์์ : ์์ธก ์ฑ๋ฅ ์ ์๋ฅผ ๊ธฐ๋กํฉ๋๋ค output_vectors.append(score_for_y_t_index)
ย
ย
์ดํ ์ ๋ฉ์ปค๋์ฆ ์์ธํ ์์๋ณด๊ธฐ
์ด ์์ ์์๋ ์ดํ
์
๋ฉ์ปค๋์ฆ์ ๋์์ ์ดํดํ๋ ๊ฒ์ ๋ชฉํ๋ก ํฉ๋๋ค. ์ด์ ๊ธ ์์ ์ค๋ช
ํ๋ ์ดํ
์
๋ฉ์ปค๋์ฆ์ ๋ชจ๋ธ์ ๋ค์ ํ ๋ฒ ๊ฐ์ ธ์ ์์์ ์ผ๋ก ์ดํด๋ณด๊ฒ ์ต๋๋ค.
ย
ate๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํ ๋์ฝ๋ RNN์ ์
๋ ฅ์ผ๋ก ๋ค์ด์ค๋ ๋์ฝ๋ ์๋ ์ํ ๊ฐ์ ์ฟผ๋ฆฌ(Query)๋ผ๊ณ ๋ถ๋ฅด๊ณ , ์ธ์ฝ๋ RNN์ ๊ฐ ์ถ๋ ฅ๊ฐ๋ค์ ํค(Key), ๊ฐ(Value)๋ผ๊ณ ๋ถ๋ฆ
๋๋ค. ์ด๋ค ๋จ์ด๋ค์ ์ง์คํ ์ง๋ฅผ ๋ํ๋ด๋ ๊ฐ์ธ ์ดํ
์
๊ฐ(Attention Score)๋ฅผ ๊ตฌํ๊ธฐ ์ํด, 8-3์์ ์ค๋ช
ํ๋ ์๋์ง ๊ฐ(Energy - ๋จ์ด๋ผ๋ฆฌ ์ผ๋ง๋ ์ฐ๊ด์ฑ์ด ์๋๊ฐ)๋ฅผ ๊ตฌํด์ผ ํฉ๋๋ค. ์ด๋ฅผ ์ํด, ๋์ฝ๋์ ์๋ ๊ฐ์ธ ์ฟผ๋ฆฌ์ ์ธ์ฝ๋์ ๊ฐ ์๋ ๊ฐ์ธ ํค๋ฅผ ๋ด์ ํด์ค๋๋ค. ์ฌ๊ธฐ์์ ์ฐ์ฐ ๊ฒฐ๊ณผ๋ก ์ค์นผ๋ผ ๊ฐ์ ์ป๊ธฐ ์ํด, ๋์ฝ๋ ๊ฐ์ ์ ์นํ์ฌ ๋ด์ ์ ์ํํฉ๋๋ค.
ย
๊ตฌํ ์ผ๋ จ์ ์ค์นผ๋ผ๊ฐ๋ค์ 0๊ณผ 1์ฌ์ด์ ๊ฐ์ด๋ฉด์ ์ด ํฉ์ด 1์ธ ํ๋ฅ ๋ถํฌ๋ก ๋ณํํ๊ธฐ ์ํด ์ํํธ๋งฅ์ค ํจ์๋ฅผ ์ ์ฉํด์ค๋๋ค. ์ํํธ๋งฅ์ค ํจ์๋ ์ฃผ์ด์ง ๊ฐ๋ค์ ๋น์จ์ ์ ์งํ๋ฉด์ ์ด ํฉ์ด 1์ด ๋๋๋ก ๋ง๋ค์ด์ฃผ๋, ํ๋ฅ ๋ถํฌ๋ก ๋ณํํ๊ธฐ ์ํด ์ฌ์ฉํ๋ ํจ์์
๋๋ค. ์ด ๊ฐ์ด ์ ์ด๋ฏธ์ง์์๋ ๊ฐ์ค์น๋ก ํํ๋์ด์์ต๋๋ค.
ย
๋ค์์ ๋ค๋ฃฐ ํ์ด์ง๋ง, ์ค๋ช
์ ๋์์ ์ฃผ๊ธฐ ์ํด ์์ผ๋ก ๊ฐ์ ธ์์ต๋๋ค. ๋์ฝ๋์ ์ดํ
์
ํ๋ฅ ๋ถํฌ๋ฅผ ์๋ ํ์์ ์ดํด๋ณผ ์ ์์ต๋๋ค. ์ด๋, ์์ค ๋ฌธ์ฅ๊ณผ ๋ฒ์ญ๋ ๋ฌธ์ฅ์ ๊ฐ ๋จ์ด๊ฐ์ ๋์ ๊ด๊ณ๋ฅผ ๊ฐ๊ณ ์๋ ๋จ์ด๋ผ๋ฆฌ ๋์ ํ๋ฅ ๋ถํฌ๋ฅผ ๋ณด์ฌ์ค๋๋ค. ํ๋ฅ ๋ถํฌ๊ฐ ๋์ ๋จ์ด์ ๊ฐ์ ๋ ๋ง์ ๊ฐ์ค์น๋ฅผ ์ฃผ๊ณ , ๋จ์ด ๋ฒ์ญํ ๋ ๊ด๋ จ์ฑ์ด ๋์ ๋จ์ด์ ์ด์ ์ ๋ง์ถ๋ค๊ณ ์๊ฐํ ์ ์์ต๋๋ค.
ย
๊ฐ ์ธ์ฝ๋์ ์๋ ๊ฐ๊ณผ ์์ ๊ตฌํ ๊ฐ์ค์น๋ฅผ ๊ณฑํ ๋ค, ๊ณฑํ ๊ฐ๋ค์ ํฉ์ธ ๊ฐ์ค ํฉ(Weighted Sum)์ ๊ตฌํฉ๋๋ค. ์ด๋ฅผ ํตํด ์ดํ
์
๊ฐ (Attention Score, Attention Value,) ๋๋ ๋ฌธ๋งฅ์ ๋ด๊ณ ์๋ ๋ฒกํฐ์ธ ๋ฌธ๋งฅ ๋ฒกํฐ(Context Vector)๋ผ๊ณ ๋ถ๋ฆ
๋๋ค. ์ด ์ดํ
์
๊ฐ์ ํตํด, ๋ชจ๋ธ์ ๊ฐ ๋จ์ด๋ฅผ ์ฒ๋ฆฌํ๊ธฐ ์ํด ์ด๋ค ๋จ์ด์ ์ง์คํด์ผ ๋ ์ข์ ์ฑ๋ฅ์ ์ป์ ์ ์๋์ง๋ฅผ ํ์ตํ ์ ์์ต๋๋ค.
ย
ย
์๋๋ ์ดํ
์
๋ฉ์ปค๋์ฆ์ ์ฝ๋๋ก ๊ตฌํํ ๊ฒ์
๋๋ค. ์ฒซ ๋ฒ์งธ ์ดํ
์
ํจ์์ธ
verbose_attention
์ ์์์ ์ค๋ช
ํ ์ดํ
์
๋ฉ์ปค๋์ฆ์ ํ ์ค์ ํ๋์ฉ ์์ธํ๊ฒ ์ค๋ช
ํด๋ ๊ฒ์ด๊ณ , ๋ ๋ฒ์งธ terse_attention
์ matmul
์ ์ฌ์ฉํด ์กฐ๊ธ ๋ ํจ์จ์ ์ผ๋ก ์ฐ์ฐํ๋ ๊ณผ์ ์ ๊ตฌํํ ํจ์์
๋๋ค. def verbose_attention(encoder_state_vectors, query_vector): """ ์์๋ณ ์ฐ์ฐ์ ์ฌ์ฉํ๋ ์ดํ ์ ๋ฉ์ปค๋์ฆ ๋ฒ์ ๋งค๊ฐ๋ณ์: encoder_state_vectors (torch.Tensor): ์ธ์ฝ๋์ ์๋ฐฉํฅ GRU์์ ์ถ๋ ฅ๋ 3์ฐจ์ ํ ์ query_vector (torch.Tensor): ๋์ฝ๋ GRU์ ์๋ ์ํ """ batch_size, num_vectors, vector_size = encoder_state_vectors.size() vector_scores = torch.sum(encoder_state_vectors * query_vector.view(batch_size, 1, vector_size), dim=2) # ์ฟผ๋ฆฌ ๊ฐ๊ณผ ๊ฐ ์ธ์ฝ๋ RNN์ ์๋ ๋ฒกํฐ๋ค์ ๊ณฑ vector_probabilities = F.softmax(vector_scores, dim=1) # ๊ฐ๋ค์ 0~1์ ๊ฐ์ผ๋ก ํ๋ฅ ๋ถํฌํ๋ก ๊ฐ์ค์น ๊ตฌํ๊ธฐ weighted_vectors = encoder_state_vectors * vector_probabilities.view(batch_size, num_vectors, 1) # ์๋ ๋ฒกํฐ๊ฐ๊ณผ ๊ฐ์ค์น์ ๊ณฑ ๊ตฌํ๊ธฐ context_vectors = torch.sum(weighted_vectors, dim=1) # ์ด ๊ฐ๋ค์ ํฉ์ผ๋ก ์ปจํ ์คํธ ๋ฒกํฐ ๋๋ ์ดํ ์ ๊ฐ ๊ตฌํ๊ธฐ return context_vectors, vector_probabilities, vector_scores def terse_attention(encoder_state_vectors, query_vector): """ ์ ๊ณฑ์ ์ฌ์ฉํ๋ ์ดํ ์ ๋ฉ์ปค๋์ฆ ๋ฒ์ ๋งค๊ฐ๋ณ์: encoder_state_vectors (torch.Tensor): ์ธ์ฝ๋์ ์๋ฐฉํฅ GRU์์ ์ถ๋ ฅ๋ 3์ฐจ์ ํ ์ query_vector (torch.Tensor): ๋์ฝ๋ GRU์ ์๋ ์ํ """ vector_scores = torch.matmul(encoder_state_vectors, query_vector.unsqueeze(dim=2)).squeeze() # matmul ํจ์๋ฅผ ์ด์ฉํด ์ฟผ๋ฆฌ*ํค ๊ฐ ๊ตฌํ๊ธฐ vector_probabilities = F.softmax(vector_scores, dim=-1) # ๊ฐ์ค์น ๊ตฌํ๊ธฐ context_vectors = torch.matmul(encoder_state_vectors.transpose(-2, -1), vector_probabilities.unsqueeze(dim=2)).squeeze() # ์ปจํ ์คํธ ๋ฒกํฐ ๊ฐ ๊ตฌํ๊ธฐ return context_vectors, vector_probabilities
ย
์ค์ผ์ค๋ง๋ ์ํ๋ง Scheduled Sampling
ํ์ต ๊ณผ์ ์์ ์ฃผ์ด์ง๋ ๋ฐ์ดํฐ์์๋ ํ๊ฒ ์ํ์ค๊ฐ ์ ๊ณต๋๊ณ , ์ด๋ฅผ ์ด์ฉํด ๊ฐ ํ์์คํ
๋ง๋ค ์ฐ์ฐ๊ณผ ํ์ต์ ์งํํ์ง๋ง, ์ค์ ๋ฐ์ดํฐ ํน์ ํ
์คํธ ๋ฐ์ดํฐ์์๋ ๋ชจ๋ธ์ด ๋ง๋๋ ์ํ์ค๊ฐ ์ด๋ค์ง ์ ์ ์๊ธฐ ๋๋ฌธ์, ์ด๋ฌํ ๊ณผ์ ์ด ์๋ํ์ง ์์ ์ ์์ต๋๋ค. ์ฆ, ํ์ต ์์๋ ํ๊ฒ ์ํ์ค๊ฐ ์์ง๋ง, ํ
์คํธ์์๋ ํ๊ฒ ์ํ์ค๊ฐ ์์ด ๊ฐ์ค์น๊ฐ ํฌ๊ฒ ๋ฒ์ด๋๋ ๋ฌธ์ ๊ฐ ๋ฐ์ํ ์ ์์ต๋๋ค.
ย
์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด, ์ํ๋ง ๊ธฐ๋ฒ์ ํตํด ํ์ต ๊ณผ์ ์์๋ ์ผ๋ถ ํ๊ฒ ์ํ์ค๋ฅผ ๋ชจ๋ธ์๊ฒ ๋งก๊ธฐ๋ ๋ฐฉ๋ฒ์ ํ์ฉํฉ๋๋ค. ๊ฐ๋จํ๊ฒ ๋งํ๋ฉด, ์ํ๋ง ๊ธฐ๋ฒ์ ๋ฐ์ดํฐ ์ค ์ผ๋ถ๋ง ๋ฝ์์ ํ์ฉํ๋ค๋ ์๋ฏธ์
๋๋ค. ํ์ต ๊ณผ์ ์์ ์ฃผ์ด์ง ํ๊ฒ ์ํ์ค์ ๋ชจ๋ธ์ด ์์ฒด ์์ฑํ ์ํ์ค๋ฅผ ๋ฌด์์๋ก ์ฌ์ฉํ๋ฉฐ, ๋ชจ๋ธ์ด ๊ฒฐ์ ํ๋ ์ํ์ค์ ํ๋ฅ ๋ถํฌ๊ฐ ๊ฐ์ ๋๋๋ก ๋ชจ๋ธ์ ํ์ต์ํต๋๋ค.
ย
์ด๋ฅผ ์ํด ์ํ๋ง ํ๋ฅ ์ ๋จผ์ ์ค์ ํด๋ก๋๋ค. ์ด๊ธฐ ์ธ๋ฑ์ค๋ฅผ ์์ ํ ํฐ์ธ BEGIN ์ธ๋ฑ์ค๋ก ๋จผ์ ์ง์ ํ๊ณ ์์ํฉ๋๋ค. ๊ทธ ๋ค์, ์ํ์ค ์์ฑ๋ฌธ์ ๋ฐ๋ณตํ ๋ ๋๋คํ ๊ฐ(๋์)๋ฅผ ๋ฐ์์ํค๊ณ , ํ๋ฅ ๋ณด๋ค ์์ผ๋ฉด ๋ชจ๋ธ์ ์์ธก ์ํ์ค๋ฅผ ์ฌ์ฉํ๊ณ , ํ๋ฅ ๋ณด๋ค ํฌ๋ฉด ์ฃผ์ด์ง ํ๊ฒ ์ํ์ค๋ฅผ ์ฌ์ฉํ๋ฉด์ ์์ธก์ ์ํํ๊ณ ํ์ต์ ์งํํฉ๋๋ค.
ย
์๋๋ ์ค์ผ์ค๋ง๋ ์ํ๋ง์ ์์ ์ ์ํ๋ NMTDecoder์
forward
ํจ์๋ฅผ ์์ ํด ์์ฑํ ๋ด์ฉ์
๋๋ค. sample_probability
๊ฐ์ ์ง์ ํด์ค ๋, 0์ ๊ฐ๊น์ธ์๋ก ์์ธก ์ํ์ค๋ฅผ ์ฌ์ฉํ๊ณ , 1์ ๊ฐ๊น์ธ์๋ก ํ๊ฒ ์ํ์ค๋ฅผ ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค. ๋, ๋ฐ๋ณต๋ฌธ๋ง๋ค use_sample = np.random.random() < sample_probability
๋ผ๋ ์ฝ๋๋ฅผ ์ฌ์ฉํ์ฌ, ๋์ ๊ฐ์ ๊ธฐ์ค์ผ๋ก ํ๋ฅ ์ ์๊ฑฐํด ๋ชจ๋ธ ์์ธก ์ํ์ค๋ฅผ ์ฌ์ฉํ ์ง, ์๋๋ฉด ํ๊ฒ ์ํ์ค๋ฅผ ์ฌ์ฉํ ์ง ๊ฒฐ์ ํ๊ฒ ๋ฉ๋๋ค. ย
class NMTDecoder(nn.Module): def __init__(self, num_embeddings, embedding_size, rnn_hidden_size, bos_index): super(NMTDecoder, self).__init__() # ์์ ์์ฑํ ์ฝ๋์ ๋์ผํ ์ ์ # ์๋ต def forward(self, encoder_state, initial_hidden_state, target_sequence, sample_probability=0.0): """ ๋ชจ๋ธ์ ์ ๋ฐฉํฅ ๊ณ์ฐ ๋งค๊ฐ๋ณ์: encoder_state (torch.Tensor): NMTEncoder์ ์ถ๋ ฅ initial_hidden_state (torch.Tensor): NMTEncoder์ ๋ง์ง๋ง ์๋ ์ํ target_sequence (torch.Tensor): ํ๊น ํ ์คํธ ๋ฐ์ดํฐ ํ ์ sample_probability (float): ์ค์ผ์ค๋ง๋ ์ํ๋ง ํ๋ผ๋ฏธํฐ ๋์ฝ๋ ํ์ ์คํ ๋ง๋ค ๋ชจ๋ธ ์์ธก์ ์ฌ์ฉํ ํ๋ฅ ๋ฐํ๊ฐ: output_vectors (torch.Tensor): ๊ฐ ํ์ ์คํ ์ ์์ธก ๋ฒกํฐ """ # ์ํ ํ๋ฅ ์ ์ง์ ํด์ค๋๋ค # 0๊ณผ 1 ์ฌ์ด์ ์ ๋นํ ๊ฐ์ ์ง์ ํด์ค๋๋ค # 0 : ์์ธก ์ํ์ค๋ง์ ์ฌ์ฉ # 1 : ํ๊ฒ ์ํ์ค๋ง์ ์ฌ์ฉ if target_sequence is None: sample_probability = 0.5 else: # ๊ฐ์ : ์ฒซ ๋ฒ์งธ ์ฐจ์์ ๋ฐฐ์น ์ฐจ์์ ๋๋ค # ์ฆ ์ ๋ ฅ์ (Batch, Seq) # ์ํ์ค์ ๋ํด ๋ฐ๋ณตํด์ผ ํ๋ฏ๋ก (Seq, Batch)๋ก ์ฐจ์์ ๋ฐ๊ฟ๋๋ค target_sequence = target_sequence.permute(1, 0) output_sequence_size = target_sequence.size(0) # ์ฃผ์ด์ง ์ธ์ฝ๋์ ์๋ ์ํ๋ฅผ ์ด๊ธฐ ์๋ ์ํ๋ก ์ฌ์ฉํฉ๋๋ค h_t = self.hidden_map(initial_hidden_state) batch_size = encoder_state.size(0) # ๋ฌธ๋งฅ ๋ฒกํฐ๋ฅผ 0์ผ๋ก ์ด๊ธฐํํฉ๋๋ค context_vectors = self._init_context_vectors(batch_size) # ์ฒซ ๋จ์ด y_t๋ฅผ BOS๋ก ์ด๊ธฐํํฉ๋๋ค y_t_index = self._init_indices(batch_size) h_t = h_t.to(encoder_state.device) y_t_index = y_t_index.to(encoder_state.device) context_vectors = context_vectors.to(encoder_state.device) output_vectors = [] self._cached_p_attn = [] self._cached_ht = [] self._cached_decoder_state = encoder_state.cpu().detach().numpy() # ๋ฐ๋ณต๋ฌธ ์ด์ ๊น์ง๋ ๊ธฐ์กด์ Decoder ์ฝ๋์ ๋์ผํฉ๋๋ค for i in range(output_sequence_size): # ์ค์ผ์ค๋ง๋ ์ํ๋ง ์ฌ์ฉ ์ฌ๋ถ # ์์ฑํ ๋์์ ํ๋ฅ ๊ฐ์ ๋น๊ตํด ์ํ๋ง ์ฌ์ฉ ์ ๋ฌด๋ฅผ ์ ํํฉ๋๋ค use_sample = np.random.random() < sample_probability if not use_sample: y_t_index = target_sequence[i] # ๋จ๊ณ 1: ๋จ์ด๋ฅผ ์๋ฒ ๋ฉํ๊ณ ์ด์ ๋ฌธ๋งฅ๊ณผ ์ฐ๊ฒฐํฉ๋๋ค y_input_vector = self.target_embedding(y_t_index) rnn_input = torch.cat([y_input_vector, context_vectors], dim=1) # ๋จ๊ณ 2: GRU๋ฅผ ์ ์ฉํ๊ณ ์๋ก์ด ์๋ ๋ฒกํฐ๋ฅผ ์ป์ต๋๋ค h_t = self.gru_cell(rnn_input, h_t) self._cached_ht.append(h_t.cpu().detach().numpy()) # ๋จ๊ณ 3: ํ์ฌ ์๋ ์ํ๋ฅผ ์ฌ์ฉํด ์ธ์ฝ๋์ ์ํ๋ฅผ ์ฃผ๋ชฉํฉ๋๋ค context_vectors, p_attn, _ = verbose_attention(encoder_state_vectors=encoder_state, query_vector=h_t) # ๋ถ๊ฐ ์์ : ์๊ฐํ๋ฅผ ์ํด ์ดํ ์ ํ๋ฅ ์ ์ ์ฅํฉ๋๋ค self._cached_p_attn.append(p_attn.cpu().detach().numpy()) # ๋จ๊ฒ 4: ํ์ฌ ์๋ ์ํ์ ๋ฌธ๋งฅ ๋ฒกํฐ๋ฅผ ์ฌ์ฉํด ๋ค์ ๋จ์ด๋ฅผ ์์ธกํฉ๋๋ค prediction_vector = torch.cat((context_vectors, h_t), dim=1) score_for_y_t_index = self.classifier(F.dropout(prediction_vector, 0.3)) if use_sample: p_y_t_index = F.softmax(score_for_y_t_index * self._sampling_temperature, dim=1) # _, y_t_index = torch.max(p_y_t_index, 1) y_t_index = torch.multinomial(p_y_t_index, 1).squeeze() # ๋ถ๊ฐ ์์ : ์์ธก ์ฑ๋ฅ ์ ์๋ฅผ ๊ธฐ๋กํฉ๋๋ค output_vectors.append(score_for_y_t_index) output_vectors = torch.stack(output_vectors).permute(1, 0, 2) return output_vectors
ย
์ด ๊ณผ์ ์ ํตํด, ๋ชจ๋ธ์ ์ํ์ค๋ฅผ ์์ธกํ๋ ๋ฐฉ๋ฒ์ ํ์ตํ์ฌ, ๋ชจ๋ธ์ด ์์ธกํ ์ํ์ค์ ๊ฐ์ค์น๊ฐ ์ ๋ต์์ ํฌ๊ฒ ๋ฒ์ด๋๋ ๊ฒฝ์ฐ๋ฅผ ์ค์ฌ์ค๋๋ค.
ย
๋ชจ๋ธ ํ๋ จ
์ด๋ฒ ์ฅ์์ ๋ค๋ฃฌ ๋ชจ๋ธ์ ํ๋ จ ๊ณผ์ ์ ์์ 6์ฅ๊ณผ 7์ฅ์์ ๋ค๋ฃฌ ๋ชจ๋ธ์ ํ๋ จ ๊ณผ์ ๊ณผ ๋น์ทํฉ๋๋ค.
- ์์ค ์ํธ์ค์ ํ๊ฒ ์ํ์ค๋ฅผ ์ ๋ ฅ๋ฐ์, ํ๊ฒ ์ํ์ค ์์ธก ์์ฑ
- ํ๊ฒ ์ํ์ค ์์ธก ๋ ์ด๋ธ์ ํตํด ํฌ๋ก์ค ์ํธ๋กํผ ์์ค ๊ณ์ฐ ํฌ๋ก์ค ์ํธ๋กํผ ์์ค์ ํ๋ฅ ๋ถํฌ์ ์์ธก ๋ถํฌ ์ฌ์ด์ ์ฐจ์ด๋ฅผ ๊ณ์ฐํ์ฌ, ๋ถ๋ฅ ๋ชจ๋ธ์ด ์์ธก์ ์ ์ํํ๋์ง๋ฅผ ํ๊ฐํ๋ ์งํ์ ๋๋ค.
- ์ญ์ ํ๋ฅผ ํตํด ๊ทธ๋๋์ธํธ๋ฅผ ๊ณ์ฐ
- ์ตํฐ๋ง์ด์ ๋ฅผ ํตํด ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธ
ย
์ ๊ณผ์ ์ผ๋ก ํ๋ จํ ๋ชจ๋ธ์ ๋ํด, ์์ค ๋ฌธ์ฅ๊ณผ ๋ชจ๋ธ์ ์ํด ์์ฑ๋ ๋ฌธ์ฅ ์์ ๋ํด BLEU ์งํ๋ก ํ๊ฐํ์ฌ ๋ชจ๋ธ์ด ์ผ๋ง๋ ์ ์๋ํ๋์ง ํ์ธํฉ๋๋ค. ์ฐ๋ฆฌ๋ ์์์ 2๊ฐ์ง ๋ชจ๋ธ์ ์ดํด๋ดค์ต๋๋ค. 1) ์ ๊ณต๋ ํ๊ฒ ์ํ์ค๋ฅผ ๋ฐํ์ผ๋ก ๋์ฝ๋์ ์
๋ ฅํ๋ ๋ชจ๋ธ. 2) ์ค์ผ์ค๋ง๋ ์ํ๋ง์ ํตํด ์์ฒด ์์ธก์ ๋ง๋ค์ด ๋์ฝ๋์ ์
๋ ฅํ๋ ๋ชจ๋ธ. ์ด ์ค์์ 2๋ฒ์งธ ๋ชจ๋ธ, ์์ฒด ์์ธก์ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ ์ํ์ค ์์ธก์ ๋ํ ์ค๋ฅ๋ฅผ ์ต์ ํํ๋๋ก ํ๋ ์ฅ์ ์ ๊ฐ๊ณ ์์ต๋๋ค. ๋ ๋ชจ๋ธ์ BLEU ์ ์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค. 1๋ฒ ๋ชจ๋ธ๋ณด๋ค 2๋ฒ ๋ชจ๋ธ์์ ์กฐ๊ธ ๋ ๋์ ์ ์๋ฅผ ๋ณด์ด๊ณ ์์์ ํ์ธํ ์ ์์ต๋๋ค.
๋ชจ๋ธ | BLEU |
์ ๊ณต๋ ํ๊ฒ ์ํ์ค๋ฅผ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ ๋ชจ๋ธ | 46.8 |
์ค์ผ์ค๋ง๋ ์ํ๋ง์ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ ๋ชจ๋ธ | 48.1 |
ย
์ด์ ๊ธ ์ฝ๊ธฐ
ย