- '''
-   code by Tae Hwan Jung(Jeff Jung) @graykode
- '''
+ # code by Tae Hwan Jung(Jeff Jung) @graykode
import numpy as np
import torch
import torch.nn as nn
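The hunk skips a few unchanged lines here; since `Variable(...)` is used throughout the code below, the full file presumably keeps an import along the lines of:

    from torch.autograd import Variable  # assumed: needed by make_batch and the hidden-state init below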
# S: Symbol that shows the start of decoding input
# E: Symbol that shows the end of decoding output
# P: Symbol that pads the sequence when the current batch data is shorter than the time steps
+ sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']
- char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz ']
- num_dic = {n: i for i, n in enumerate(char_arr)}
- dic_len = len(num_dic)
+ word_list = " ".join(sentences).split()
+ word_list = list(set(word_list))
+ word_dict = {w: i for i, w in enumerate(word_list)}
+ n_class = len(word_dict)
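With the `sentences` above, the deduplicated vocabulary has exactly 11 tokens (8 distinct words plus 'S', 'E', 'P'), so every one-hot vector below has length 11; only the index assignment varies between runs, because `set()` is unordered. A quick sanity check:

    assert n_class == 11  # ich, mochte, ein, bier, i, want, a, beer, S, E, P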

# Parameter
- max_len = 20
+ max_len = 5  # n_step (= seq_len); the added 'S', 'E', or 'P' token makes each sequence 5 tokens long
n_hidden = 128
- total_epoch = 10000
- n_class = dic_len
+ batch_size = 1

- seq_data = [['Ich mochte ein bier', 'I want a BEER']]
+ def make_batch(sentences):
+     input_batch = [np.eye(n_class)[[word_dict[n] for n in sentences[0].split()]]]
+     output_batch = [np.eye(n_class)[[word_dict[n] for n in sentences[1].split()]]]
+     target_batch = [[word_dict[n] for n in sentences[2].split()]]
- def make_batch(seq_data):
-     input_batch = []
-     output_batch = []
-     target_batch = []
+     # make tensor
+     return Variable(torch.Tensor(input_batch)), Variable(torch.Tensor(output_batch)), Variable(torch.LongTensor(target_batch))
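For the single triple in `sentences`, the returned tensors have fixed shapes: the two one-hot batches are [batch_size=1, n_step=5, n_class=11] and the target is a [1, 5] index tensor. A minimal shape check, using only names defined in this file:

    ib, ob, tb = make_batch(sentences)
    print(ib.shape, ob.shape, tb.shape)  # torch.Size([1, 5, 11]) twice, then torch.Size([1, 5])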
+
+ class Attention(nn.Module):
+     def __init__(self):
+         super(Attention, self).__init__()
+         self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
+         self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
+
+         # Linear for attention
+         self.attn = nn.Linear(n_hidden, n_hidden)
+
+     def forward(self, enc_input, hidden, dec_input):
+         enc_input = enc_input.transpose(0, 1)  # enc_input: [max_len(=n_step, time step), batch_size, n_class]
+         dec_input = dec_input.transpose(0, 1)  # dec_input: [max_len(=n_step, time step), batch_size, n_class]
-     for seq in seq_data:
-         for i in range(2):
-             seq[i] = seq[i] + 'P' * (max_len - len(seq[i]))
+         # enc_outputs : [max_len, batch_size, num_directions(=1) * n_hidden]
+         # enc_states : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
+         enc_outputs, enc_states = self.enc_cell(enc_input, hidden)
+         dec_outputs, _ = self.dec_cell(dec_input, enc_states)
-         input = [num_dic[n] for n in seq[0]]
-         output = [num_dic[n] for n in ('S' + seq[1])]
-         target = [num_dic[n] for n in (seq[1] + 'E')]
+
+         return dec_outputs
+
+     def get_att_weight(self, hidden, enc_outputs):
+         attn_scores = Variable(torch.zeros(len(enc_outputs)))  # attn_scores : [n_step]
-         input_batch.append(np.eye(dic_len)[input])
-         output_batch.append(np.eye(dic_len)[output])
+
+     def get_att_score(self, hidden, encoder_hidden):
+         score = self.attn(encoder_hidden)
+         return torch.dot(hidden.view(-1), score.view(-1))
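`get_att_score` is a Luong-style "general" score: score(h_dec, h_enc) = h_dec · (W h_enc), with W being `self.attn`. The commit shows `get_att_weight` only up to allocating the score buffer; a plausible completion (my sketch, not code from this commit) would score every encoder step and softmax-normalize, roughly:

    import torch.nn.functional as F  # assumed import

    def get_att_weight(self, hidden, enc_outputs):
        n_step = len(enc_outputs)
        attn_scores = Variable(torch.zeros(n_step))  # attn_scores : [n_step]
        for i in range(n_step):
            # one dot-product score per encoder time step
            attn_scores[i] = self.get_att_score(hidden, enc_outputs[i])
        # normalize to weights that sum to 1, shaped for bmm: [1, 1, n_step]
        return F.softmax(attn_scores, dim=0).view(1, 1, -1)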
-         target_batch.append(target)
+ input_batch, output_batch, target_batch = make_batch(sentences)

-     return input_batch, output_batch, target_batch
+ # hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
+ hidden = Variable(torch.zeros(1, 1, n_hidden))
+
+ model = Attention()
+ output = model(input_batch, hidden, output_batch)
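Note that `forward` as committed returns raw decoder states of shape [n_step, batch_size, n_hidden] and never calls `get_att_weight`, so `output` here is not yet class logits. Training against `target_batch` would additionally need a projection from n_hidden down to n_class; one hedged way to wire that up (the `out` layer and the loop below are my assumptions, not part of this commit):

    # assumed extra layer in Attention.__init__: self.out = nn.Linear(n_hidden, n_class)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(2000):
        optimizer.zero_grad()
        dec_outputs = model(input_batch, hidden, output_batch)  # [n_step, 1, n_hidden]
        logits = model.out(dec_outputs).squeeze(1)  # [n_step, n_class]; uses the assumed self.out
        loss = criterion(logits, target_batch.squeeze(0))  # targets: [n_step]
        loss.backward()
        optimizer.step()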