The Chinese pre-trained XLNet model has finally been released; see https://github.com/ymcui/Chinese-PreTrained-XLNet. Once it came out, I tried it on a Chinese text classification task. XLNet changes quite a few things compared with BERT, although not that much at the model-architecture level. The Chinese model used here has 24 layers, twice as deep as the 12-layer Chinese BERT. Most of the work is in the Chinese data preprocessing: the model tokenizes with SentencePiece, padding is prepended to the sequence (with segment id SEG_ID_PAD, and mask value 1 marking padded positions, the opposite of BERT's convention), and the model input is laid out as [seq_len, batch] instead of [batch, seq_len], so the segment_ids and input_mask differ from the usual BERT setup. Let's look at the code directly below.
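To make these input conventions concrete, here is a minimal sketch in plain Python (names and values chosen for illustration only) of how a short example is left-padded and how the mask and segment ids line up; the constants mirror the ones used in the preprocessing script below.

# Illustration only: left padding and mask convention for XLNet inputs.
# 0 in input_mask marks a real token, 1 marks padding (opposite of BERT).
SEG_ID_A, SEG_ID_CLS, SEG_ID_PAD = 0, 2, 4
SEP_ID, CLS_ID = 4, 3

def pad_left(token_ids, max_seq_length):
    # token_ids already ends with <sep> and <cls>
    input_ids = list(token_ids)
    input_mask = [0] * len(input_ids)
    segment_ids = [SEG_ID_A] * (len(input_ids) - 1) + [SEG_ID_CLS]
    delta = max_seq_length - len(input_ids)
    return ([0] * delta + input_ids,           # pad ids in front
            [1] * delta + input_mask,          # padded positions masked with 1
            [SEG_ID_PAD] * delta + segment_ids)

ids, mask, segs = pad_left([17, 88, 42, SEP_ID, CLS_ID], max_seq_length=8)
print(ids)   # [0, 0, 0, 17, 88, 42, 4, 3]
print(mask)  # [1, 1, 1, 0, 0, 0, 0, 0]
print(segs)  # [4, 4, 4, 0, 0, 0, 0, 2]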
First, the data is converted to TFRecord files:
import collections
import six
import unicodedata

import sentencepiece as spm
import tensorflow as tf

from textclass import FLAGS

# Segment ids used by XLNet
SEG_ID_A = 0
SEG_ID_B = 1
SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4

special_symbols = {
    "<unk>": 0,
    "<s>": 1,
    "</s>": 2,
    "<cls>": 3,
    "<sep>": 4,
    "<pad>": 5,
    "<mask>": 6,
    "<eod>": 7,
    "<eop>": 8,
}

VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]

sp = spm.SentencePieceProcessor()
sp.Load(FLAGS.spiece_model_file)


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncate the longer of the two sequences until they fit max_length."""
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def get_class_ids(text, max_seq_length, tokenize_fn):
    """Build input_ids/input_mask/segment_ids for a single text (classification)."""
    texts = tokenize_fn(text)
    if len(texts) > max_seq_length - 2:
        texts = texts[:max_seq_length - 2]
    tokens = []
    segment_ids = []
    for token in texts:
        tokens.append(token)
        segment_ids.append(SEG_ID_A)
    tokens.append(SEP_ID)
    segment_ids.append(SEG_ID_A)
    tokens.append(CLS_ID)
    segment_ids.append(SEG_ID_CLS)

    input_ids = tokens
    # XLNet convention: 0 marks a real token, 1 marks padding.
    input_mask = [0] * len(input_ids)
    if len(input_ids) < max_seq_length:
        delta_len = max_seq_length - len(input_ids)
        # Padding is prepended (left padding).
        input_ids = [0] * delta_len + input_ids
        input_mask = [1] * delta_len + input_mask
        segment_ids = [SEG_ID_PAD] * delta_len + segment_ids

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    return input_ids, input_mask, segment_ids


def get_pair_ids(text_a, text_b, max_seq_length, tokenize_fn):
    """Build inputs for a sentence pair (not needed for single-text classification)."""
    tokens_a = tokenize_fn(text_a)
    tokens_b = tokenize_fn(text_b)
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    tokens = []
    segment_ids = []
    for token in tokens_a:
        tokens.append(token)
        segment_ids.append(SEG_ID_A)
    tokens.append(SEP_ID)
    segment_ids.append(SEG_ID_A)
    for token in tokens_b:
        tokens.append(token)
        segment_ids.append(SEG_ID_B)
    tokens.append(SEP_ID)
    segment_ids.append(SEG_ID_B)
    tokens.append(CLS_ID)
    segment_ids.append(SEG_ID_CLS)

    input_ids = tokens
    input_mask = [0] * len(input_ids)
    if len(input_ids) < max_seq_length:
        delta_len = max_seq_length - len(input_ids)
        input_ids = [0] * delta_len + input_ids
        input_mask = [1] * delta_len + input_mask
        segment_ids = [SEG_ID_PAD] * delta_len + segment_ids

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    return input_ids, input_mask, segment_ids


SPIECE_UNDERLINE = '▁'


def encode_pieces(sp_model, text, return_unicode=True, sample=False):
    """Tokenize text into SentencePiece pieces (same logic as the official XLNet repo)."""
    if six.PY2 and isinstance(text, unicode):
        text = text.encode('utf-8')

    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    new_pieces = []
    for piece in pieces:
        if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
            cur_pieces = sp_model.EncodeAsPieces(
                piece[:-1].replace(SPIECE_UNDERLINE, ''))
            if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                if len(cur_pieces[0]) == 1:
                    cur_pieces = cur_pieces[1:]
                else:
                    cur_pieces[0] = cur_pieces[0][1:]
            cur_pieces.append(piece[-1])
            new_pieces.extend(cur_pieces)
        else:
            new_pieces.append(piece)

    # note(zhiliny): convert back to unicode for py2
    if six.PY2 and return_unicode:
        ret_pieces = []
        for piece in new_pieces:
            if isinstance(piece, str):
                piece = piece.decode('utf-8')
            ret_pieces.append(piece)
        new_pieces = ret_pieces
    return new_pieces


def encode_ids(sp_model, text, sample=False):
    pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
    ids = [sp_model.PieceToId(piece) for piece in pieces]
    return ids


def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
    if remove_space:
        outputs = ' '.join(inputs.strip().split())
    else:
        outputs = inputs
    outputs = outputs.replace("``", '"').replace("''", '"')

    if six.PY2 and isinstance(outputs, str):
        outputs = outputs.decode('utf-8')

    if not keep_accents:
        outputs = unicodedata.normalize('NFKD', outputs)
        outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
    if lower:
        outputs = outputs.lower()
    return outputs


def tokenize_fn(text):
    text = preprocess_text(text, lower=True)
    return encode_ids(sp, text)


def get_vocab(path):
    """Load the label vocabulary: one label per line -> label id."""
    maps = collections.defaultdict()
    i = 0
    with tf.gfile.GFile(path, "r") as f:
        for line in f.readlines():
            maps[line.strip()] = i
            i = i + 1
    return maps


def writedataclass(inputpath, vocab, outputpath, max_seq_length, tokenize_fn):
    """Convert a tab-separated 'text<TAB>label' file into sharded TFRecord files."""
    eachonum = 5000  # examples per TFRecord shard
    num = 0
    recordfilenum = 0
    ftrecordfilename = ("xlnetreading.tfrecords-%.3d" % recordfilenum)
    writer = tf.python_io.TFRecordWriter(outputpath + ftrecordfilename)
    with open(inputpath) as f:
        for text in f.readlines():
            texts = text.split("\t")
            content = texts[0].lower().strip()
            label = vocab.get(texts[1].strip())
            num = num + 1
            input_ids, input_mask, segment_ids = get_class_ids(
                content, max_seq_length, tokenize_fn)
            if num > eachonum:
                # Roll over to a new shard (close the previous writer first).
                num = 1
                recordfilenum = recordfilenum + 1
                ftrecordfilename = ("xlnetreading.tfrecords-%.3d" % recordfilenum)
                writer.close()
                writer = tf.python_io.TFRecordWriter(outputpath + ftrecordfilename)

            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids)),
                        'input_mask': tf.train.Feature(int64_list=tf.train.Int64List(value=input_mask)),
                        'segment_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=segment_ids)),
                        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
                    }))
            serialized = example.SerializeToString()
            writer.write(serialized)
    writer.close()
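For completeness, the conversion could be driven like this; the file names and max_seq_length below are illustrative assumptions, not part of the original post.

# Hypothetical usage of the conversion helpers above (paths and length are examples).
if __name__ == "__main__":
    vocab = get_vocab("label_vocab.txt")           # one label per line
    writedataclass(inputpath="train.txt",          # lines of "text<TAB>label"
                   vocab=vocab,
                   outputpath="./tfrecords/",
                   max_seq_length=256,
                   tokenize_fn=tokenize_fn)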
I then wrote a text classification class; here it is:
import datetime
import time
from datetime import timedelta

import tensorflow as tf

# Modules from the official XLNet repository.
import modeling
import model_utils
import xlnet

# FLAGS (checkpoint paths, batch size, summary settings, etc.) is assumed to be
# defined elsewhere in this script with tf.app.flags and is not shown here.


class XlnetReadingClass(object):
    def __init__(self, model_config_path, is_training, FLAGS, input_ids,
                 segment_ids, input_mask, label, n_class):
        self.xlnet_config = xlnet.XLNetConfig(json_path=model_config_path)
        self.run_config = xlnet.create_run_config(is_training, True, FLAGS)

        # XLNet expects inputs shaped [seq_len, batch], so transpose here.
        self.input_ids = tf.transpose(input_ids, [1, 0])
        self.segment_ids = tf.transpose(segment_ids, [1, 0])
        self.input_mask = tf.transpose(input_mask, [1, 0])

        self.model = xlnet.XLNetModel(
            xlnet_config=self.xlnet_config,
            run_config=self.run_config,
            input_ids=self.input_ids,
            seg_ids=self.segment_ids,
            input_mask=self.input_mask)

        cls_scope = FLAGS.cls_scope
        summary = self.model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)

        self.per_example_loss, self.logits = modeling.classification_loss(
            hidden=summary,
            labels=label,
            n_class=n_class,
            initializer=self.model.get_initializer(),
            scope=cls_scope,
            return_logits=True)

        self.total_loss = tf.reduce_mean(self.per_example_loss)

        with tf.name_scope("train_op"):
            self.train_op, _, _ = model_utils.get_train_op(FLAGS, self.total_loss)

        with tf.name_scope("acc"):
            one_hot_target = tf.one_hot(label, n_class)
            self.acc = self.accuracy(self.logits, one_hot_target)

    def accuracy(self, logits, labels):
        arglabels_ = tf.argmax(tf.nn.softmax(logits), 1)
        arglabels = tf.argmax(tf.squeeze(labels), 1)
        acc = tf.to_float(tf.equal(arglabels_, arglabels))
        return tf.reduce_mean(acc)


def main(_):
    print('Loading config...')
    n_class = 38
    input_path = FLAGS.data_dir + "xlnetreading.tfrecords*"
    print("input_path:", input_path)
    files = tf.train.match_filenames_once(input_path)

    # `inputs` reads the TFRecord shards written earlier and returns batched
    # tensors; see the sketch after this block for one possible implementation.
    input_ids, input_mask, segment_ids, label_ids = inputs(
        files, batch_size=FLAGS.batch_size, num_epochs=5,
        max_seq_length=FLAGS.max_seq_length)

    model_config_path = FLAGS.model_config_path
    is_training = False
    init_checkpoint = FLAGS.init_checkpoint

    model = XlnetReadingClass(model_config_path, is_training, FLAGS, input_ids,
                              segment_ids, input_mask, label_ids, n_class)

    tvars = tf.trainable_variables()
    if init_checkpoint:
        (assignment_map, initialized_variable_names
         ) = model_utils.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        print("restore success on cpu or gpu")

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    session.run(tf.local_variables_initializer())

    print("**** Trainable Variables ****")
    for var in tvars:
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
            print("name ={0}, shape = {1}{2}".format(var.name, var.shape, init_string))

    print("xlnet reading class model will start train .........")
    print(session.run(files))

    saver = tf.train.Saver()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=session)

    start_time = time.time()
    for i in range(8000):
        _, loss_train, acc = session.run([model.train_op, model.total_loss, model.acc])
        if i % 100 == 0:
            end_time = time.time()
            time_dif = end_time - start_time
            time_dif = timedelta(seconds=int(round(time_dif)))
            msg = 'Iter: {0:>6}, Train Loss: {1:>6.2},' \
                  + ' Cost: {2} Time:{3} acc:{4}'
            print(msg.format(i, loss_train, time_dif,
                             datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), acc))
            start_time = time.time()
        if i % 500 == 0 and i > 0:
            saver.save(session, "../exp/reading/model.ckpt", global_step=i)

    coord.request_stop()
    coord.join(threads)
    session.close()
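The main() function above calls an inputs(...) helper that is not included in the original post. Here is a minimal sketch of what it could look like, assuming the TF 1.x queue-based pipeline that start_queue_runners and match_filenames_once imply; the function name, capacity values, and tensor dtypes are my assumptions.

# Sketch of the missing `inputs` helper (assumption, not from the original post):
# reads the TFRecord shards written by writedataclass and returns batched tensors.
import tensorflow as tf

def inputs(files, batch_size, num_epochs, max_seq_length):
    filename_queue = tf.train.string_input_producer(files, num_epochs=num_epochs)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'input_ids': tf.FixedLenFeature([max_seq_length], tf.int64),
            'input_mask': tf.FixedLenFeature([max_seq_length], tf.int64),
            'segment_ids': tf.FixedLenFeature([max_seq_length], tf.int64),
            'label': tf.FixedLenFeature([1], tf.int64),
        })

    input_ids = tf.cast(features['input_ids'], tf.int32)
    input_mask = tf.cast(features['input_mask'], tf.float32)  # XLNet expects a float mask
    segment_ids = tf.cast(features['segment_ids'], tf.int32)
    label = tf.cast(features['label'], tf.int32)

    # Shuffle and batch; capacities here are arbitrary illustrative values.
    input_ids, input_mask, segment_ids, label = tf.train.shuffle_batch(
        [input_ids, input_mask, segment_ids, label],
        batch_size=batch_size,
        capacity=10000,
        min_after_dequeue=1000)

    # Return labels as a [batch_size] vector for classification_loss/one_hot.
    return input_ids, input_mask, segment_ids, tf.squeeze(label, axis=1)

Note that the batched tensors come out as [batch, seq_len]; the transposes inside XlnetReadingClass turn them into the [seq_len, batch] layout the XLNet model expects.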