Skip to content

Commit

Permalink
Add tests for Word2Vec
Browse files Browse the repository at this point in the history
- validates that there are no leaked tensors
- validates basic functionality of nearest
- validates that add, subtract and average return things

- Fixes memory leaks in add, subtract, average, and addOrSubtract functions
- Adds a general dispose to the Word2Vec class
  • Loading branch information
meiamsome committed Jun 26, 2018
1 parent d8bb936 commit f8b83e2
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 50 deletions.
75 changes: 44 additions & 31 deletions src/Word2vec/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,31 @@ class Word2Vec {
});
}

dispose() {
Object.values(this.model).forEach(x => x.dispose());
}

add(inputs, max = 1) {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
return tf.tidy(() => {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
console.log(sum);
return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
});
}

subtract(inputs, max = 1) {
const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
return tf.tidy(() => {
const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
});
}

average(inputs, max = 1) {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
const avg = tf.div(sum, tf.tensor(inputs.length));
return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
return tf.tidy(() => {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
const avg = tf.div(sum, tf.tensor(inputs.length));
return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
});
}

nearest(input, max = 10) {
Expand All @@ -64,34 +75,36 @@ class Word2Vec {
}

static addOrSubtract(model, values, operation) {
const vectors = [];
const notFound = [];
if (values.length < 2) {
throw new Error('Invalid input, must be passed more than 1 value');
}
values.forEach((value) => {
const vector = model[value];
if (!vector) {
notFound.push(value);
} else {
vectors.push(vector);
return tf.tidy(() => {
const vectors = [];
const notFound = [];
if (values.length < 2) {
throw new Error('Invalid input, must be passed more than 1 value');
}
});
values.forEach((value) => {
const vector = model[value];
if (!vector) {
notFound.push(value);
} else {
vectors.push(vector);
}
});

if (notFound.length > 0) {
throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
}
let result = vectors[0];
if (operation === 'ADD') {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.add(result, vectors[i]);
if (notFound.length > 0) {
throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
}
} else {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.sub(result, vectors[i]);
let result = vectors[0];
if (operation === 'ADD') {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.add(result, vectors[i]);
}
} else {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.sub(result, vectors[i]);
}
}
}
return result;
return result;
});
}

static nearest(model, input, start, max) {
Expand Down
115 changes: 96 additions & 19 deletions src/Word2vec/index_test.js
Original file line number Diff line number Diff line change
@@ -1,19 +1,96 @@
// import Word2Vec from './index';

// const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/07_Word2Vec/data/wordvecs1000.json';

// describe('initialize word2vec', () => {
// let word2vec;
// beforeAll((done) => {
// // word2vec = new Word2Vec(URL);
// done();
// });

// // it('creates a new instance', (done) => {
// // expect(word2vec).toEqual(jasmine.objectContaining({
// // ready: true,
// // modelSize: 1,
// // }));
// // done();
// // });
// });
const { tf, word2vec } = ml5;

const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/Word2Vec/data/wordvecs1000.json';

describe('word2vec', () => {
let word2vecInstance;
let numTensorsBeforeAll;
let numTensorsBeforeEach;
beforeAll((done) => {
numTensorsBeforeAll = tf.memory().numTensors;
word2vecInstance = word2vec(URL, done);
});

afterAll(() => {
word2vecInstance.dispose();
let numTensorsAfterAll = tf.memory().numTensors;
if(numTensorsBeforeAll !== numTensorsAfterAll) {
throw new Error(`Leaking Tensors (${numTensorsAfterAll} vs ${numTensorsBeforeAll})`);
}
});

beforeEach(() => {
numTensorsBeforeEach = tf.memory().numTensors;
});

afterEach(() => {
let numTensorsAfterEach = tf.memory().numTensors;
if(numTensorsBeforeEach !== numTensorsAfterEach) {
throw new Error(`Leaking Tensors (${numTensorsAfterEach} vs ${numTensorsBeforeEach})`);
}
});

it('creates a new instance', () => {
expect(word2vecInstance).toEqual(jasmine.objectContaining({
ready: true,
modelSize: 1,
}));
});

describe('getRandomWord', () => {
it('returns a word', () => {
let word = word2vecInstance.getRandomWord();
expect(typeof word).toEqual('string');
});
});

describe('nearest', () => {
it('returns a sorted array of nearest words', () => {
for(let i = 0; i < 100; i++) {
let word = word2vecInstance.getRandomWord();
let nearest = word2vecInstance.nearest(word);
let currentDistance = 0;
for(let { word, distance: nextDistance } of nearest) {
expect(typeof word).toEqual('string');
expect(nextDistance).toBeGreaterThan(currentDistance);
currentDistance = nextDistance;
}
}
});

it('returns a list of the right length', () => {
for(let i = 0; i < 100; i++) {
let word = word2vecInstance.getRandomWord();
let nearest = word2vecInstance.nearest(word, i);
expect(nearest.length).toEqual(i);
}
});
});

describe('add', () => {
it('returns a value', () => {
let word1 = word2vecInstance.getRandomWord();
let word2 = word2vecInstance.getRandomWord();
let sum = word2vecInstance.subtract([word1, word2]);
expect(sum[0].distance).toBeGreaterThan(0);
})
});

describe('subtract', () => {
it('returns a value', () => {
let word1 = word2vecInstance.getRandomWord();
let word2 = word2vecInstance.getRandomWord();
let sum = word2vecInstance.subtract([word1, word2]);
expect(sum[0].distance).toBeGreaterThan(0);
})
});

describe('average', () => {
it('returns a value', () => {
let word1 = word2vecInstance.getRandomWord();
let word2 = word2vecInstance.getRandomWord();
let average = word2vecInstance.average([word1, word2]);
expect(average[0].distance).toBeGreaterThan(0);
});
});
});

0 comments on commit f8b83e2

Please sign in to comment.