# -*- coding: utf-8 -*-
"""
create_documents.py

Created on Fri Apr 16 14:54:31 2021
@author: kane
"""
# /home/Anaconda3/bin/es-serving-start -model_dir /home/NLP_bert/chinese_roberta_wwm_ext_L-12_H-768_A-12 -max_seq_len=120 -max_batch_size=30
import argparse
import json

import pandas as pd
from bert_serving.client import BertClient
  11. server_ip = "192.168.20.68"
  12. bc = BertClient(ip=server_ip, output_fmt='list')
  13. def create_document(doc, emb, index_name):
  14. return {
  15. '_op_type': 'index',
  16. '_index': index_name,
  17. 'phenomenon': doc["phenomenon"],
  18. 'phenomenon_vector': emb
  19. }
  20. def load_dataset(path):
  21. docs = []
  22. df = pd.read_csv(path)
  23. for row in df.iterrows():
  24. series = row[1]
  25. doc = {
  26. 'phenomenon': series.phenomenon,
  27. }
  28. docs.append(doc)
  29. return docs
  30. def bulk_predict(docs, batch_size=256):
  31. """Predict es embeddings."""
  32. for i in range(0, len(docs), batch_size):
  33. batch_docs = docs[i: i + batch_size]
  34. breakdown_show_embeddings = bc.encode([doc['phenomenon'] for doc in batch_docs])
  35. for emb in breakdown_show_embeddings:
  36. yield emb
  37. def main(args):
  38. docs = load_dataset(args.data)
  39. with open(args.save, 'w') as f:
  40. for doc, emb in zip(docs, bulk_predict(docs)):
  41. d = create_document(doc, emb, args.index_name)
  42. f.write(json.dumps(d) + '\n')
  43. # bert转换的结果写到文件中
  44. if __name__ == '__main__':
  45. parser = argparse.ArgumentParser(description='Creating elasticsearch documents.')
  46. parser.add_argument('--data', default='phenomenon.csv', help='data for creating documents.')
  47. parser.add_argument('--save', default='documents.jsonl', help='created documents.')
  48. parser.add_argument('--index_name', default='fault_meter', help='Elasticsearch index name.')
  49. args = parser.parse_args()
  50. main(args)