|
|
|
@ -26,7 +26,7 @@ def train(args): |
|
|
|
|
sys.exit(-1) |
|
|
|
|
|
|
|
|
|
contracts = load_contracts(dirname, ext=ext, nsamples=nsamples) |
|
|
|
|
logger.info('Saving extracted data into', last_data_train_filename) |
|
|
|
|
logger.info('Saving extracted data into %s', last_data_train_filename) |
|
|
|
|
cache = [] |
|
|
|
|
with open(last_data_train_filename, 'w') as f: |
|
|
|
|
for filename in contracts: |
|
|
|
@ -40,13 +40,15 @@ def train(args): |
|
|
|
|
logger.info('Starting training') |
|
|
|
|
model = train_unsupervised(input=last_data_train_filename, model='skipgram') |
|
|
|
|
logger.info('Training complete') |
|
|
|
|
logger.info('Saving model') |
|
|
|
|
model.save_model(model_filename) |
|
|
|
|
|
|
|
|
|
for i,(filename, contract, function, irs) in enumerate(cache): |
|
|
|
|
cache[i] = ((filename, contract, function), model.get_sentence_vector(irs)) |
|
|
|
|
|
|
|
|
|
logger.info('Saved cache in cache.npz') |
|
|
|
|
logger.info('Saving cache in cache.npz') |
|
|
|
|
save_cache(cache, "cache.npz") |
|
|
|
|
logger.info('Done!') |
|
|
|
|
|
|
|
|
|
except Exception: |
|
|
|
|
logger.error('Error in %s' % args.filename) |
|
|
|
|