framework

tensorflow常用模板

李勐
本文记录如何使用tensorflow实现基于transformer的轨迹分类。 在技术实现上包含以下关键细节需要注意: 轨迹点如果超出模型序列长度则进行中间截断,如果不足则进行中间补0 轨迹数据需要对单条数据进行归一化 只使用transformer的encoder部分进行建模,不使用embedding或positional embedding机制,这种方式工程实现简单,但模型可能无法捕获序列前后的区域信息 代码实现 首先实现transformer encoder,内部结构为self-attention和feed forward网络。 def scaled_dot_product_attention(q, k, v, mask=None): matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) dk = tf.cast(tf.shape(k)[-1], tf.float32) scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) if mask is not None: scaled_attention_logits += (mask * -1e9) attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k) output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) return output, attention_weights class MultiHeadAttention(tf.keras.layers.Layer): def __init__(self, d_model, num_heads, dropout): super(MultiHeadAttention, self).