Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Word2Vec tests #173

Merged
merged 4 commits into from
Jun 27, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add tests for Word2Vec
- validates that there are no leaked tensors
- validates basic functionality of nearest
- validates that add, subtract and average return things

- Fixes memory leaks in add, subtract, average, and addOrSubtract functions
- Adds a general dispose to the Word2Vec class
  • Loading branch information
meiamsome committed Jun 26, 2018
commit f8b83e27dd6979d15b279ccdc5800636336b32fc
75 changes: 44 additions & 31 deletions src/Word2vec/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,31 @@ class Word2Vec {
});
}

dispose() {
Object.values(this.model).forEach(x => x.dispose());
}

add(inputs, max = 1) {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
return tf.tidy(() => {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
console.log(sum);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you forgot to remove this log

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch :)

return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
});
}

subtract(inputs, max = 1) {
const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
return tf.tidy(() => {
const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
});
}

average(inputs, max = 1) {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
const avg = tf.div(sum, tf.tensor(inputs.length));
return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
return tf.tidy(() => {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
const avg = tf.div(sum, tf.tensor(inputs.length));
return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
});
}

nearest(input, max = 10) {
Expand All @@ -64,34 +75,36 @@ class Word2Vec {
}

static addOrSubtract(model, values, operation) {
const vectors = [];
const notFound = [];
if (values.length < 2) {
throw new Error('Invalid input, must be passed more than 1 value');
}
values.forEach((value) => {
const vector = model[value];
if (!vector) {
notFound.push(value);
} else {
vectors.push(vector);
return tf.tidy(() => {
const vectors = [];
const notFound = [];
if (values.length < 2) {
throw new Error('Invalid input, must be passed more than 1 value');
}
});
values.forEach((value) => {
const vector = model[value];
if (!vector) {
notFound.push(value);
} else {
vectors.push(vector);
}
});

if (notFound.length > 0) {
throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
}
let result = vectors[0];
if (operation === 'ADD') {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.add(result, vectors[i]);
if (notFound.length > 0) {
throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
}
} else {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.sub(result, vectors[i]);
let result = vectors[0];
if (operation === 'ADD') {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.add(result, vectors[i]);
}
} else {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.sub(result, vectors[i]);
}
}
}
return result;
return result;
});
}

static nearest(model, input, start, max) {
Expand Down
115 changes: 96 additions & 19 deletions src/Word2vec/index_test.js
Original file line number Diff line number Diff line change
@@ -1,19 +1,96 @@
// import Word2Vec from './index';

// const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/07_Word2Vec/data/wordvecs1000.json';

// describe('initialize word2vec', () => {
// let word2vec;
// beforeAll((done) => {
// // word2vec = new Word2Vec(URL);
// done();
// });

// // it('creates a new instance', (done) => {
// // expect(word2vec).toEqual(jasmine.objectContaining({
// // ready: true,
// // modelSize: 1,
// // }));
// // done();
// // });
// });
const { tf, word2vec } = ml5;

const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/Word2Vec/data/wordvecs1000.json';

describe('word2vec', () => {
let word2vecInstance;
let numTensorsBeforeAll;
let numTensorsBeforeEach;
// Suite-level leak check: remember how many tensors exist before the model
// is loaded, then verify the count is back to that value after dispose().
beforeAll((done) => {
  const { numTensors } = tf.memory();
  numTensorsBeforeAll = numTensors;
  word2vecInstance = word2vec(URL, done);
});

afterAll(() => {
  word2vecInstance.dispose();
  const { numTensors: remaining } = tf.memory();
  if (remaining === numTensorsBeforeAll) {
    return;
  }
  throw new Error(`Leaking Tensors (${remaining} vs ${numTensorsBeforeAll})`);
});

// Per-test leak check: each spec must leave the tensor count exactly as it
// found it.
beforeEach(() => {
  const { numTensors } = tf.memory();
  numTensorsBeforeEach = numTensors;
});

afterEach(() => {
  const { numTensors: remaining } = tf.memory();
  if (remaining === numTensorsBeforeEach) {
    return;
  }
  throw new Error(`Leaking Tensors (${remaining} vs ${numTensorsBeforeEach})`);
});

// The instance built in beforeAll should report itself as loaded.
// NOTE(review): modelSize: 1 looks surprising for a data file named
// wordvecs1000.json — confirm the expected value against the constructor.
it('creates a new instance', () => {
  const loadedInstance = jasmine.objectContaining({ ready: true, modelSize: 1 });
  expect(word2vecInstance).toEqual(loadedInstance);
});

// getRandomWord() only has to hand back a string.
describe('getRandomWord', () => {
  it('returns a word', () => {
    expect(typeof word2vecInstance.getRandomWord()).toEqual('string');
  });
});

// nearest(): results must come back ordered by strictly increasing distance
// and honour the requested result count.
describe('nearest', () => {
  const ROUNDS = 100;

  it('returns a sorted array of nearest words', () => {
    for (let round = 0; round < ROUNDS; round += 1) {
      const query = word2vecInstance.getRandomWord();
      const neighbours = word2vecInstance.nearest(query);
      // Every distance has to exceed the one before it, starting from 0.
      let previousDistance = 0;
      for (const neighbour of neighbours) {
        expect(typeof neighbour.word).toEqual('string');
        expect(neighbour.distance).toBeGreaterThan(previousDistance);
        previousDistance = neighbour.distance;
      }
    }
  });

  it('returns a list of the right length', () => {
    // The loop counter doubles as the requested number of results (0..99).
    for (let size = 0; size < ROUNDS; size += 1) {
      const query = word2vecInstance.getRandomWord();
      expect(word2vecInstance.nearest(query, size).length).toEqual(size);
    }
  });
});

// Smoke test for add(): combining two random words must yield at least one
// result with a positive distance.
describe('add', () => {
  it('returns a value', () => {
    const word1 = word2vecInstance.getRandomWord();
    const word2 = word2vecInstance.getRandomWord();
    // Exercise add() itself; calling subtract() here left add() untested.
    const sum = word2vecInstance.add([word1, word2]);
    expect(sum[0].distance).toBeGreaterThan(0);
  });
});

describe('subtract', () => {
it('returns a value', () => {
let word1 = word2vecInstance.getRandomWord();
let word2 = word2vecInstance.getRandomWord();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it technically possible the same word could be picked twice?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is, but the probability is 0.0001%

let sum = word2vecInstance.subtract([word1, word2]);
expect(sum[0].distance).toBeGreaterThan(0);
})
});

// Smoke test for average(): averaging two random words must yield at least
// one result with a positive distance.
describe('average', () => {
  it('returns a value', () => {
    const firstWord = word2vecInstance.getRandomWord();
    const secondWord = word2vecInstance.getRandomWord();
    const closest = word2vecInstance.average([firstWord, secondWord])[0];
    expect(closest.distance).toBeGreaterThan(0);
  });
});
});