forked from NTMC-Community/MatchZoo
Showing 3 changed files with 189 additions and 0 deletions.
@@ -0,0 +1,79 @@
{
  "net_name": "CONVKNRM",
  "global": {
    "model_type": "PY",
    "weights_file": "examples/wikiqa/weights/knrm.wikiqa.weights",
    "save_weights_iters": 10,
    "num_iters": 100,
    "display_interval": 100,
    "test_weights_iters": 10,
    "optimizer": "adam",
    "learning_rate": 0.001
  },
  "inputs": {
    "share": {
      "text1_corpus": "./data/WikiQA/corpus_preprocessed.txt",
      "text2_corpus": "./data/WikiQA/corpus_preprocessed.txt",
      "use_dpool": false,
      "embed_size": 300,
      "embed_path": "./data/WikiQA/embed_glove_d300_norm",
      "vocab_size": 18680,
      "train_embed": false,
      "target_mode": "ranking",
      "text1_maxlen": 10,
      "text2_maxlen": 40
    },
    "train": {
      "input_type": "PairGenerator",
      "phase": "TRAIN",
      "use_iter": false,
      "query_per_iter": 50,
      "batch_per_iter": 5,
      "batch_size": 16,
      "relation_file": "./data/WikiQA/relation_train.txt"
    },
    "valid": {
      "input_type": "ListGenerator",
      "phase": "EVAL",
      "batch_list": 10,
      "relation_file": "./data/WikiQA/relation_valid.txt"
    },
    "test": {
      "input_type": "ListGenerator",
      "phase": "EVAL",
      "batch_list": 10,
      "relation_file": "./data/WikiQA/relation_test.txt"
    },
    "predict": {
      "input_type": "ListGenerator",
      "phase": "PREDICT",
      "batch_list": 10,
      "relation_file": "./data/WikiQA/relation_test.txt"
    }
  },
  "outputs": {
    "predict": {
      "save_format": "TREC",
      "save_path": "predict.test.knrm_ranking.txt"
    }
  },
  "model": {
    "model_path": "./matchzoo/models/",
    "model_py": "conv_knrm.CONVKNRM",
    "setting": {
      "kernel_num": 11,
      "sigma": 0.1,
      "exact_sigma": 0.001,
      "max_ngram": 3,
      "if_crossmatch": true
    }
  },
  "losses": [
    {
      "object_name": "rank_hinge_loss",
      "object_params": { "margin": 1.0 }
    }
  ],
  "metrics": [ "ndcg@3", "ndcg@5", "map" ]
}
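For reference, the kernel schedule this config implies is easy to work out by hand. A minimal sketch in plain Python, mirroring the mu/sigma computation in conv_knrm.CONVKNRM below and assuming the config values above (kernel_num=11, sigma=0.1, exact_sigma=0.001):

kernel_num, sigma, exact_sigma = 11, 0.1, 0.001
for i in range(kernel_num):
    mu = 1.0 / (kernel_num - 1) + (2.0 * i) / (kernel_num - 1) - 1.0
    if mu > 1.0:
        mu, sig = 1.0, exact_sigma   # last kernel is pinned to exact matches
    else:
        sig = sigma
    print('kernel %2d: mu=%+.1f sigma=%.3f' % (i, mu, sig))

This gives ten soft-match kernels centered at -0.9, -0.7, ..., 0.9, plus one near-delta kernel pinned at 1.0 (with exact_sigma) that fires only on exact matches.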
@@ -0,0 +1,10 @@
cd ../../

currpath=`pwd`
# train the model
python matchzoo/main.py --phase train --model_file ${currpath}/examples/wikiqa/config/conv_knrm_wikiqa.config

# predict with the model
python matchzoo/main.py --phase predict --model_file ${currpath}/examples/wikiqa/config/conv_knrm_wikiqa.config
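The predict phase writes its rankings to predict.test.knrm_ranking.txt in TREC run format, per the outputs section of the config above. A minimal sketch for reading such a run back, assuming the standard six-column TREC layout (qid Q0 docid rank score tag; the helper name load_trec_run is ours):

from collections import defaultdict

def load_trec_run(path):
    # map qid -> list of (docid, rank, score) tuples
    run = defaultdict(list)
    with open(path) as f:
        for line in f:
            qid, _, docid, rank, score, _tag = line.split()[:6]
            run[qid].append((docid, int(rank), float(score)))
    return run

run = load_trec_run('predict.test.knrm_ranking.txt')
print('%d queries ranked' % len(run))

A run file in this layout can also be scored against a qrels file with the standard trec_eval tool.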
@@ -0,0 +1,100 @@
# -*- coding=utf-8 -*-
import keras
import keras.backend as K
from keras.models import Sequential, Model
from keras.layers import *
from keras.layers import Input, Embedding, Dense, Activation, Lambda, Dot
from keras.activations import softmax
from keras import initializers
from keras.initializers import Constant, RandomNormal
from model import BasicModel
from utils.utility import *

class CONVKNRM(BasicModel):
    def __init__(self, config):
        super(CONVKNRM, self).__init__(config)
        self._name = 'CONVKNRM'
        self.check_list = ['text1_maxlen', 'kernel_num', 'sigma', 'exact_sigma',
                           'embed', 'embed_size', 'vocab_size', 'max_ngram', 'if_crossmatch']
        self.setup(config)
        if not self.check():
            raise TypeError('[CONVKNRM] parameter check failed')
        print('[CONVKNRM] init done')

    def setup(self, config):
        self.set_default('kernel_num', 11)
        self.set_default('sigma', 0.1)
        self.set_default('exact_sigma', 0.001)
        self.set_default('max_ngram', 3)
        self.set_default('if_crossmatch', True)
        if not isinstance(config, dict):
            raise TypeError('parameter config should be a dict: %s' % config)
        self.config.update(config)

    def build(self):
        def Kernel_layer(mu, sigma):
            # Gaussian kernel centered at mu: soft-counts match-matrix
            # entries whose cosine similarity falls near this level
            def kernel(x):
                return K.tf.exp(-0.5 * (x - mu) * (x - mu) / sigma / sigma)
            return Activation(kernel)

        query = Input(name='query', shape=(self.config['text1_maxlen'],))
        show_layer_info('Input', query)
        doc = Input(name='doc', shape=(self.config['text2_maxlen'],))
        show_layer_info('Input', doc)

        # one embedding layer shared by query and doc
        embedding = Embedding(self.config['vocab_size'], self.config['embed_size'], weights=[self.config['embed']], trainable=self.config['train_embed'])

        q_embed = embedding(query)
        show_layer_info('Embedding', q_embed)
        d_embed = embedding(doc)
        show_layer_info('Embedding', d_embed)
        q_convs = []
        d_convs = []
        for i in range(self.config['max_ngram']):
            # one Conv1D per n-gram length (window i+1), shared between query and doc
            c = keras.layers.Conv1D(128, i + 1, activation='relu', padding='same')
            q_convs.append(c(q_embed))
            show_layer_info('Q N-gram Embedding', q_convs[i])
            d_convs.append(c(d_embed))
            show_layer_info('D N-gram Embedding', d_convs[i])

        KM = []
        for qi in range(self.config['max_ngram']):
            for di in range(self.config['max_ngram']):
                # if crossmatch is disabled, only match n-grams of equal length
                if not self.config['if_crossmatch'] and qi != di:
                    print("non cross")
                    continue
                q_ngram = q_convs[qi]
                d_ngram = d_convs[di]
                # cosine-similarity match matrix between query and doc n-grams
                mm = Dot(axes=[2, 2], normalize=True)([q_ngram, d_ngram])
                show_layer_info('Dot', mm)

                for i in range(self.config['kernel_num']):
                    mu = 1. / (self.config['kernel_num'] - 1) + (2. * i) / (self.config['kernel_num'] - 1) - 1.0
                    sigma = self.config['sigma']
                    if mu > 1.0:
                        # the last kernel is pinned to mu=1.0 with a tight sigma,
                        # so it responds only to exact matches
                        sigma = self.config['exact_sigma']
                        mu = 1.0
                    mm_exp = Kernel_layer(mu, sigma)(mm)
                    show_layer_info('Exponent of mm:', mm_exp)
                    mm_doc_sum = Lambda(lambda x: K.tf.reduce_sum(x, 2))(mm_exp)
                    show_layer_info('Sum of document', mm_doc_sum)
                    mm_log = Activation(K.tf.log1p)(mm_doc_sum)
                    show_layer_info('Logarithm of sum', mm_log)
                    mm_sum = Lambda(lambda x: K.tf.reduce_sum(x, 1))(mm_log)
                    show_layer_info('Sum of all exponent', mm_sum)
                    KM.append(mm_sum)

        Phi = Lambda(lambda x: K.tf.stack(x, 1))(KM)
        show_layer_info('Stack', Phi)
        if self.config['target_mode'] == 'classification':
            out_ = Dense(2, activation='softmax', kernel_initializer=initializers.RandomUniform(minval=-0.014, maxval=0.014), bias_initializer='zeros')(Phi)
        elif self.config['target_mode'] in ['regression', 'ranking']:
            out_ = Dense(1, kernel_initializer=initializers.RandomUniform(minval=-0.014, maxval=0.014), bias_initializer='zeros')(Phi)
        else:
            raise ValueError('unknown target_mode: %s' % self.config['target_mode'])
        show_layer_info('Dense', out_)

        model = Model(inputs=[query, doc], outputs=[out_])
        return model
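To see what one Kernel_layer -> reduce_sum -> log1p -> reduce_sum chain computes, here is the same pipeline in plain NumPy on a toy 1x2x3 match matrix (batch x query terms x doc terms of cosine similarities); the array values are invented for illustration:

import numpy as np

def kernel(x, mu, sigma):
    return np.exp(-0.5 * (x - mu) ** 2 / sigma ** 2)

mm = np.array([[[1.0, 0.3, -0.2],
                [0.5, 0.9,  0.1]]])        # toy cosine match matrix
mm_exp = kernel(mm, mu=1.0, sigma=0.001)   # exact-match kernel response
mm_doc_sum = mm_exp.sum(axis=2)            # soft term frequency per query term
mm_log = np.log1p(mm_doc_sum)              # dampen the document-side sum
print(mm_log.sum(axis=1))                  # one scalar feature per kernel

Each (query n-gram, doc n-gram, kernel) triple contributes one such scalar, so with max_ngram=3, kernel_num=11, and if_crossmatch=true the stacked Phi vector has 3 * 3 * 11 = 99 entries.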