Skip to content

Commit

Permalink
add lr scheduler, configuration change, test code writing
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunwoongko committed Dec 18, 2019
1 parent fae3e2f commit de85b22
Show file tree
Hide file tree
Showing 18 changed files with 108 additions and 72 deletions.
26 changes: 18 additions & 8 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

My own implementation Transformer model (Attention is All You Need - Google Brain, 2017)
<br><br>
![model](image/1.png)
![model](image/model.png)
<br><br>

## Experiments
![model](image/train_result.jpg)



## Reference
|Reference|
|:---:|
Expand Down
Binary file modified __pycache__/conf.cpython-36.pyc
Binary file not shown.
Binary file modified __pycache__/data.cpython-36.pyc
Binary file not shown.
10 changes: 6 additions & 4 deletions conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model parameter setting
batch_size = 128
batch_size = 64
max_len = 512
d_model = 512
n_layers = 6
Expand All @@ -18,8 +18,10 @@
drop_prob = 0.1

# optimizer parameter setting
init_lr = 1e-4
weight_decay = 5e-4
epoch = 300
init_lr = 1e-5
factor = 0.8
patience = 7
weight_decay = 5e-3
epoch = 3000
clip = 1
inf = float('inf')
2 changes: 1 addition & 1 deletion data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from util.tokenizer import Tokenizer

tokenizer = Tokenizer()
loader = DataLoader(ext=('.en', '.de'),
loader = DataLoader(ext=('.de', '.en'),
tokenize_en=tokenizer.tokenize_en,
tokenize_de=tokenizer.tokenize_de,
init_token='<sos>',
Expand Down
9 changes: 6 additions & 3 deletions graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def read(name):
file = re.sub('\\]', '', file)
f.close()

return [float(i) * 100.0 for i in file.split(',')]
return [float(i) for idx, i in enumerate(file.split(',')) if idx <= 150]


train = read('./result/train.txt')
Expand All @@ -28,5 +28,8 @@ def read(name):
plt.ylabel('loss')
plt.title('training result')
plt.grid(True, which='both', axis='both')
plt.legend(loc='lower right')
plt.show()
plt.legend(loc='lower left')
plt.xticks([i for i in range(0, 151, 10)])
plt.yticks([i * 0.2 for i in range(17, 30)])

plt.show()
File renamed without changes
Binary file added image/train_result.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 0 additions & 10 deletions main.py

This file was deleted.

2 changes: 1 addition & 1 deletion result/test.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
[5.107369244098663]
[5.107369244098663, 5.047159373760223, 4.974853754043579, 4.995905220508575, 4.889649212360382, 4.931996583938599, 4.79099440574646, 4.610719561576843, 4.704448044300079, 4.478800296783447, 4.608596563339233, 4.411366403102875, 4.514317214488983, 4.516281008720398, 4.406989753246307, 4.398191511631012, 4.469792008399963, 4.439737498760223, 4.440478563308716, 4.456247925758362, 4.409370839595795, 4.527897357940674, 4.555103003978729, 4.448478937149048, 4.434713065624237, 4.462775468826294, 4.483190834522247, 4.462384939193726, 4.364038109779358, 4.404552459716797, 4.382501840591431, 4.291230112314224, 4.290549635887146, 4.292933970689774, 4.355917453765869, 4.360010951757431, 4.438632607460022, 4.35925555229187, 4.388172149658203, 4.309739708900452, 4.39450865983963, 4.3459999561309814, 4.312599867582321, 4.311552673578262, 4.3432484567165375, 4.369180917739868, 4.232648283243179, 4.226952880620956, 4.3540546000003815, 4.311228662729263, 4.265764266252518, 4.285617887973785, 4.333370089530945, 4.350599020719528, 4.257741242647171, 4.265554070472717, 4.314047962427139, 4.473091185092926, 4.380516469478607, 4.365982949733734, 4.413721024990082, 4.37661549448967, 4.441282421350479, 4.387831300497055, 4.382125526666641, 4.458192557096481, 4.5256989896297455, 4.4446702003479, 4.4245529770851135, 4.374544709920883, 4.623082458972931, 4.5312294363975525, 4.561788409948349, 4.5677085518836975, 4.527444243431091, 4.600746214389801, 4.4941065311431885, 4.508937358856201, 4.535274356603622, 4.506582468748093, 4.424302697181702, 4.482682317495346, 4.3892485201358795, 4.430088222026825, 4.3308519423007965, 4.2865891456604, 4.449585944414139, 4.3190436363220215, 4.351876676082611, 4.340744376182556, 4.265190571546555, 4.36338272690773, 4.310745090246201, 4.182569473981857, 4.182712346315384, 4.263973087072372, 4.131220132112503, 4.384204000234604, 4.3515565097332, 4.242527782917023, 4.078160256147385, 4.1080431044101715, 4.249444901943207, 4.073768049478531, 4.008112043142319, 4.139866441488266, 4.118567109107971, 4.009009003639221, 3.973883777856827, 4.028521358966827, 4.027924060821533, 3.941185712814331, 3.942689299583435, 4.000838190317154, 3.953184276819229, 3.9298836290836334, 3.9049587547779083, 3.935948431491852, 3.8946973085403442, 3.873660296201706, 3.855608344078064, 3.9091764390468597, 3.949636697769165, 3.977479934692383, 3.9306695461273193, 3.9233385026454926, 3.8942521810531616, 3.9339661598205566, 3.955780267715454, 3.9690974950790405, 3.9635704457759857, 3.9262157678604126, 3.8931693136692047, 3.976430982351303, 3.900678277015686, 3.9732322692871094, 3.961668938398361, 3.873569905757904, 3.9164712131023407, 4.055065453052521, 3.908220112323761, 3.8862161934375763, 4.0920418202877045, 4.01253941655159, 3.8541969060897827, 3.8702659010887146, 3.880065083503723, 3.8091674149036407, 3.863888144493103, 3.8924776911735535, 3.8309891521930695, 3.7834676802158356, 3.7963992953300476, 3.7478930950164795, 3.8485116958618164, 3.86627459526062, 3.7579129338264465, 3.7750843167304993, 3.7846836149692535, 3.823878973722458, 3.7668022513389587, 3.734569400548935, 3.826398581266403, 3.8080168962478638, 3.8877736032009125, 3.7764706313610077, 3.804745376110077, 3.773357003927231, 3.8848529160022736, 3.7165520787239075, 3.852197676897049, 3.875643253326416, 3.8159318566322327, 3.8359116315841675, 3.8654544055461884, 3.90069779753685, 3.781153082847595, 3.8338528275489807, 3.8655340373516083, 3.794357180595398, 3.812661439180374, 3.8944612443447113, 3.8269618451595306, 3.777755856513977, 3.7732381522655487, 3.784351259469986, 3.8030809462070465, 3.918800562620163, 3.858412981033325, 3.813141107559204, 3.836715519428253, 3.8063608407974243, 3.8273969888687134, 3.7943951189517975, 3.8174937069416046, 3.783056229352951, 3.809193044900894, 3.794726252555847, 3.7999748289585114, 3.7938650846481323, 3.919620633125305, 3.8781287372112274, 3.8806531131267548, 3.978741854429245, 3.8254112005233765, 3.836581826210022, 3.81705105304718, 3.9391154050827026, 3.81963711977005, 3.7944883704185486, 3.828608989715576, 3.810883104801178, 3.9169464111328125, 3.80493888258934, 3.841265231370926, 3.8392494916915894, 3.8767609894275665, 4.004000186920166, 3.8871565461158752, 3.8163305521011353, 3.9872564673423767, 3.8903215527534485, 3.9109255373477936, 3.9021213948726654, 3.7755310237407684, 3.9322194159030914, 3.805501401424408, 3.8840301632881165, 3.856250613927841, 3.9011684954166412, 3.860666126012802, 3.8833804428577423, 3.885105609893799, 3.8913139402866364, 3.8395919501781464, 3.9061848521232605, 3.9559248983860016, 3.912041962146759, 3.92180860042572, 3.810354858636856, 3.9449867010116577, 3.933736652135849, 3.91291007399559, 3.888887971639633, 3.9428194761276245, 3.9123726189136505, 3.862858086824417, 3.9613552689552307, 3.8319855630397797, 3.9507904052734375, 3.932381808757782, 3.8407545685768127, 3.9586107432842255, 3.8897804915905, 3.8290930688381195, 3.890099674463272, 3.84007328748703, 3.8194033205509186, 3.8897750675678253, 3.922900140285492, 3.8933580219745636, 3.9606508910655975, 3.89353147149086, 3.838844984769821, 3.9335989356040955, 3.9013270139694214, 3.8215660452842712, 3.8741520643234253, 3.88197860121727, 3.8777244985103607, 3.8509004414081573, 3.873890608549118, 3.778834253549576, 3.870293378829956, 3.868007928133011, 3.8560746014118195, 3.9366289377212524, 3.829409956932068, 3.8308649361133575, 3.836451083421707, 3.9306696355342865, 3.862416088581085, 3.8342210948467255, 3.8261405527591705, 3.87460196018219, 3.870723158121109, 3.8528255820274353, 3.8280787467956543, 3.8979072868824005, 3.8674525320529938, 3.9424744844436646, 3.855466842651367, 3.7723155319690704, 3.932773530483246, 3.831766039133072, 3.8109220564365387, 3.788826107978821, 3.898903965950012, 3.9026650190353394, 3.8430078327655792]
Loading

0 comments on commit de85b22

Please sign in to comment.