XLNet Chinese text classification

Keywords: Session network github

The Chinese pre-trained XLNet model has finally been released; see https://github.com/ymcui/Chinese-PreTrained-XLNet. After it came out, we tried it on a Chinese text classification task. XLNet changes many things compared with BERT, and it is also deeper: the currently released Chinese model has 24 layers, twice as many as the 12-layer Chinese BERT. We ran into quite a few pitfalls, mostly in Chinese data preprocessing. The model tokenizes with SentencePiece. Padding is added at the front of the sequence (pre-padding). The model takes its input in [seq_len, batch] (time-major) layout. The segment_ids and the input mask also follow different conventions from ordinary models: in the mask, 1 marks padding and 0 marks real tokens. Let's look at the code directly.

 

Converting the data to TFRecord format:

import  tensorflow  as tf
import sys
import six
import unicodedata
import sentencepiece as spm
import collections
from textclass import   FLAGS


# XLNet segment ids: the segment each input position is tagged with.
SEG_ID_A, SEG_ID_B, SEG_ID_CLS, SEG_ID_SEP, SEG_ID_PAD = range(5)

# Special-symbol ids, listed in id order (0..8). Presumably these match the
# SentencePiece vocabulary in use — confirm against the model file.
special_symbols = {
    symbol: idx
    for idx, symbol in enumerate(
        ("<unk>", "<s>", "</s>", "<cls>", "<sep>",
         "<pad>", "<mask>", "<eod>", "<eop>"))
}

VOCAB_SIZE = 32000
# Convenience aliases for the special ids used when building features.
UNK_ID, CLS_ID, SEP_ID, MASK_ID, EOD_ID = (
    special_symbols[name]
    for name in ("<unk>", "<cls>", "<sep>", "<mask>", "<eod>"))


# Module-level SentencePiece processor, loaded once at import time from the
# model file named by FLAGS.spiece_model_file; tokenize_fn() encodes with it.
sp = spm.SentencePieceProcessor()
sp.Load(FLAGS.spiece_model_file)

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_length:
      break
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()

def get_class_ids(text,max_seq_length,tokenize_fn):
  """Build single-text XLNet classification features for ``text``.

  Layout: ``[pad ...] tokens <sep> <cls>``. The text is cut to
  ``max_seq_length - 2`` tokens. Padding (id 0) is prepended.

  Returns:
    (input_ids, input_mask, segment_ids), each of length max_seq_length.
    input_mask is 1 on padding positions and 0 on real tokens.
  """
  # Leave room for the trailing <sep> and <cls> markers.
  body = list(tokenize_fn(text)[:max_seq_length - 2])

  input_ids = body + [SEP_ID, CLS_ID]
  # Text tokens and <sep> belong to segment A; <cls> gets its own segment id.
  segment_ids = [SEG_ID_A] * (len(body) + 1) + [SEG_ID_CLS]
  input_mask = [0] * len(input_ids)

  n_pad = max_seq_length - len(input_ids)
  if n_pad > 0:
    input_ids = [0] * n_pad + input_ids
    input_mask = [1] * n_pad + input_mask
    segment_ids = [SEG_ID_PAD] * n_pad + segment_ids

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  return input_ids, input_mask, segment_ids


def get_pair_ids(text_a,text_b,max_seq_length,tokenize_fn):
  """Build sentence-pair XLNet features for ``text_a`` / ``text_b``.

  Layout: ``[pad ...] A <sep> B <sep> <cls>``. The two texts are truncated
  jointly with _truncate_seq_pair. Padding (id 0) is prepended.

  Returns:
    (input_ids, input_mask, segment_ids), each of length max_seq_length.
    input_mask is 1 on padding positions and 0 on real tokens.
  """
  tokens_a = tokenize_fn(text_a)
  tokens_b = tokenize_fn(text_b)
  # Reserve three positions for the two <sep> markers and the <cls> marker.
  _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

  input_ids = list(tokens_a) + [SEP_ID] + list(tokens_b) + [SEP_ID, CLS_ID]
  # Each <sep> carries the segment id of the text it closes.
  segment_ids = ([SEG_ID_A] * (len(tokens_a) + 1) +
                 [SEG_ID_B] * (len(tokens_b) + 1) +
                 [SEG_ID_CLS])
  input_mask = [0] * len(input_ids)

  n_pad = max_seq_length - len(input_ids)
  if n_pad > 0:
    input_ids = [0] * n_pad + input_ids
    input_mask = [1] * n_pad + input_mask
    segment_ids = [SEG_ID_PAD] * n_pad + segment_ids

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  return input_ids, input_mask, segment_ids



SPIECE_UNDERLINE = '▁'
def encode_pieces(sp_model, text, return_unicode=True, sample=False):
  """Split ``text`` into SentencePiece pieces.

  Any piece that ends in "<digit>," is re-encoded without the comma, and the
  comma is appended as a separate piece. For example, roughly "▁2019," becomes
  "▁2019", ",".

  Args:
    sp_model: a loaded SentencePiece processor.
    text: input text (on Python 2, unicode input is encoded to UTF-8 first).
    return_unicode: on Python 2, decode byte-string pieces back to unicode.
    sample: use SampleEncodeAsPieces(text, 64, 0.1) instead of the
      deterministic EncodeAsPieces.
  """
  if six.PY2 and isinstance(text, unicode):
    text = text.encode('utf-8')

  if sample:
    raw_pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
  else:
    raw_pieces = sp_model.EncodeAsPieces(text)

  result = []
  for piece in raw_pieces:
    splits_comma = (len(piece) > 1 and piece[-1] == ','
                    and piece[-2].isdigit())
    if not splits_comma:
      result.append(piece)
      continue
    sub_pieces = sp_model.EncodeAsPieces(
        piece[:-1].replace(SPIECE_UNDERLINE, ''))
    # Re-encoding adds a leading word-boundary marker; drop it when the
    # original piece did not start a word.
    if piece[0] != SPIECE_UNDERLINE and sub_pieces[0][0] == SPIECE_UNDERLINE:
      if len(sub_pieces[0]) == 1:
        sub_pieces = sub_pieces[1:]
      else:
        sub_pieces[0] = sub_pieces[0][1:]
    sub_pieces.append(piece[-1])
    result.extend(sub_pieces)

  # note(zhiliny): convert back to unicode for py2
  if six.PY2 and return_unicode:
    result = [p.decode('utf-8') if isinstance(p, str) else p for p in result]

  return result



def encode_ids(sp_model, text, sample=False):
  """Encode ``text`` to SentencePiece ids via encode_pieces()."""
  return [sp_model.PieceToId(p)
          for p in encode_pieces(sp_model, text, return_unicode=False,
                                 sample=sample)]

def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
  """Normalize raw text before SentencePiece encoding.

  Args:
    inputs: the raw text.
    lower: lowercase the result.
    remove_space: strip the text and collapse whitespace runs to single spaces.
    keep_accents: if False, apply NFKD and drop combining characters.
  """
  outputs = ' '.join(inputs.strip().split()) if remove_space else inputs
  # LaTeX-style quote pairs become a plain double quote.
  outputs = outputs.replace("``", '"').replace("''", '"')

  if six.PY2 and isinstance(outputs, str):
    outputs = outputs.decode('utf-8')

  if not keep_accents:
    decomposed = unicodedata.normalize('NFKD', outputs)
    outputs = ''.join(ch for ch in decomposed if not unicodedata.combining(ch))

  return outputs.lower() if lower else outputs


def tokenize_fn(text):
    """Lowercase/normalize ``text`` and encode it to ids with the module ``sp``."""
    normalized = preprocess_text(text, lower=True)
    ids = encode_ids(sp, normalized)
    return ids


def get_vocab(path):
    """Map each stripped line of the label file at ``path`` to its 0-based index.

    A repeated line keeps the index of its last occurrence.
    """
    maps = collections.defaultdict()
    with tf.gfile.GFile(path, "r") as label_file:
        for index, line in enumerate(label_file.readlines()):
            maps[line.strip()] = index
    return maps


def writedataclass(inputpath, vocab, outputpath,max_seq_length,tokenize_fn):
    """Convert a ``text<TAB>label`` file into sharded TFRecord files.

    Each non-blank input line is split on tabs. The first field is the text,
    and the second is looked up in ``vocab`` to get the integer label. At most
    5000 examples are written per shard. Shards are named
    ``outputpath + "xlnetreading.tfrecords-NNN"``.

    Args:
      inputpath: path of the tab-separated input file.
      vocab: mapping from label string to integer id (e.g. from get_vocab).
      outputpath: prefix the shard file name is appended to (for a directory,
        it must end with a path separator).
      max_seq_length: padded sequence length passed to get_class_ids.
      tokenize_fn: callable turning text into a list of token ids.

    Raises:
      ValueError: if a non-blank line has no tab-separated label field, or its
        label is not in ``vocab``.
    """
    eachonum = 5000  # examples per shard
    num = 0
    recordfilenum = 0
    ftrecordfilename = ("xlnetreading.tfrecords-%.3d" % recordfilenum)
    writer = tf.python_io.TFRecordWriter(outputpath + ftrecordfilename)
    try:
        with open(inputpath) as f:
            for lineno, text in enumerate(f, 1):
                # Blank lines carry no example and would otherwise raise
                # IndexError on texts[1].
                if not text.strip():
                    continue
                texts = text.split("\t")
                if len(texts) < 2:
                    raise ValueError("line %d of %s has no tab-separated label"
                                     % (lineno, inputpath))
                content = texts[0].lower().strip()
                label_name = texts[1].strip()
                label = vocab.get(label_name)
                if label is None:
                    raise ValueError("line %d of %s: unknown label %r"
                                     % (lineno, inputpath, label_name))
                num = num + 1
                input_ids, input_mask, segment_ids = get_class_ids(
                    content, max_seq_length, tokenize_fn)
                if num > eachonum:
                    # Close the finished shard before opening the next one so
                    # its records are flushed and the file handle is released.
                    num = 1
                    writer.close()
                    writer = None
                    recordfilenum = recordfilenum + 1
                    ftrecordfilename = ("xlnetreading.tfrecords-%.3d" % recordfilenum)
                    writer = tf.python_io.TFRecordWriter(outputpath + ftrecordfilename)

                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids)),
                                 'input_mask': tf.train.Feature(int64_list=tf.train.Int64List(value=input_mask)),
                                 'segment_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=segment_ids)),
                                 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
                                 }))
                writer.write(example.SerializeToString())
    finally:
        # Always close the current shard, even if a line failed to convert.
        if writer is not None:
            writer.close()

Next, here is the text classification model class I wrote:
class XlnetReadingClass(object):
    """XLNet encoder with a classification head, training op and accuracy.

    Args:
      model_config_path: path of the XLNet JSON config file.
      is_training: forwarded to ``xlnet.create_run_config``.
      FLAGS: flag object. It supplies cls_scope, summary_type, use_summ_proj,
        and whatever ``model_utils.get_train_op`` reads.
      input_ids, segment_ids, input_mask: 2-D tensors. They are transposed with
        perm [1, 0] before being fed to ``xlnet.XLNetModel``. They are assumed
        to arrive as (batch, seq_len) — TODO confirm against the input pipeline.
      label: integer class-id tensor fed to the loss and to tf.one_hot.
      n_class: number of classes.
    """

    def __init__(self,model_config_path,is_training,FLAGS,input_ids,segment_ids,
                 input_mask,label,n_class):
        self.xlnet_config = xlnet.XLNetConfig(json_path=model_config_path)
        self.run_config = xlnet.create_run_config(is_training, True, FLAGS)
        # Swap the two axes: XLNetModel is fed time-major tensors here.
        self.input_ids = tf.transpose(input_ids, [1, 0])
        self.segment_ids = tf.transpose(segment_ids, [1, 0])
        self.input_mask = tf.transpose(input_mask, [1, 0])

        self.model = xlnet.XLNetModel(
            xlnet_config=self.xlnet_config,
            run_config=self.run_config,
            input_ids=self.input_ids,
            seg_ids=self.segment_ids,
            input_mask=self.input_mask)

        cls_scope = FLAGS.cls_scope
        summary = self.model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
        self.per_example_loss, self.logits = modeling.classification_loss(
            hidden=summary,
            labels=label,
            n_class=n_class,
            initializer=self.model.get_initializer(),
            scope=cls_scope,
            return_logits=True)

        self.total_loss = tf.reduce_mean(self.per_example_loss)

        with tf.name_scope("train_op"):
            self.train_op, _, _ = model_utils.get_train_op(FLAGS, self.total_loss)

        with tf.name_scope("acc"):
            one_hot_target = tf.one_hot(label, n_class)
            self.acc = self.accuracy(self.logits, one_hot_target)

    def accuracy(self, logits, labels):
        """Return mean top-1 accuracy of ``logits`` against one-hot ``labels``.

        Args:
          logits: (batch, n_class) scores.
          labels: one-hot targets whose last axis is the class axis.
        """
        # Softmax is monotonic, so argmax over the raw logits is equivalent.
        predictions = tf.argmax(logits, 1)
        # Flatten to (batch, n_class) instead of tf.squeeze: squeeze would also
        # drop the batch axis when batch_size == 1 and break argmax(axis=1).
        flat_labels = tf.reshape(labels, [-1, tf.shape(labels)[-1]])
        targets = tf.argmax(flat_labels, 1)
        correct = tf.cast(tf.equal(predictions, targets), tf.float32)
        return tf.reduce_mean(correct)


def main(_):
    """Fine-tune the XLNet text classifier on the sharded TFRecords.

    Reads ``FLAGS.data_dir + "xlnetreading.tfrecords*"`` and builds the model.
    If ``FLAGS.init_checkpoint`` is set, it warm-starts from that checkpoint.
    It then runs up to 8000 training steps, logging every 100 steps and saving
    a checkpoint every 500 steps.
    """
    # These modules are not imported at file level, so import them here.
    import datetime
    import time

    print('Loading config...')

    n_class = 38  # label ids fed to tf.one_hot must be < n_class

    input_path = FLAGS.data_dir + "xlnetreading.tfrecords*"

    print("input_path:", input_path)
    files = tf.train.match_filenames_once(input_path)

    # `inputs` is the user's input pipeline (not shown here). It returns the
    # batched input_ids / input_mask / segment_ids / label tensors.
    input_ids, input_mask, segment_ids, label_ids = inputs(
        files, batch_size=FLAGS.batch_size, num_epochs=5,
        max_seq_length=FLAGS.max_seq_length)
    model_config_path = FLAGS.model_config_path
    # The loop below runs train_op, so build the graph in training mode.
    # XLNet's create_run_config zeroes dropout/dropatt when is_training=False.
    is_training = True
    init_checkpoint = FLAGS.init_checkpoint

    model = XlnetReadingClass(model_config_path, is_training, FLAGS, input_ids,
                              segment_ids, input_mask, label_ids, n_class)

    tvars = tf.trainable_variables()

    # Defined unconditionally so the variable listing below also works when
    # no checkpoint is configured.
    initialized_variable_names = {}
    if init_checkpoint:
        (assignment_map, initialized_variable_names) = \
            model_utils.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        print("restore sucess  on cpu or gpu")

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    # match_filenames_once stores its result in a local variable.
    session.run(tf.local_variables_initializer())

    print("**** Trainable Variables ****")
    for var in tvars:
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
            print("name ={0}, shape = {1}{2}".format(var.name, var.shape,
                                                     init_string))

    print("xlnet reading class  model will start train .........")

    print(session.run(files))
    saver = tf.train.Saver()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=session)
    start_time = time.time()
    try:
        for i in range(8000):
            _, loss_train, acc = session.run(
                [model.train_op, model.total_loss, model.acc])
            if i % 100 == 0:
                time_dif = datetime.timedelta(
                    seconds=int(round(time.time() - start_time)))
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2},' \
                      + '  Cost: {2}  Time:{3}  acc:{4}'
                print(msg.format(i, loss_train, time_dif,
                                 datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), acc))
                start_time = time.time()
            if i % 500 == 0 and i > 0:
                saver.save(session, "../exp/reading/model.ckpt", global_step=i)
    except tf.errors.OutOfRangeError:
        # Presumably raised once `inputs` has delivered all num_epochs.
        print("input data exhausted, stopping training")
    finally:
        # Stop the queue-runner threads and release the session on any exit.
        coord.request_stop()
        coord.join(threads)
        session.close()

 

Posted by kashmirekat on Tue, 01 Oct 2019 09:03:56 -0700