<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Deep-Learning on NEVERMORE</title><link>https://hych0317.github.io/hugoweb_auto/categories/deep-learning/</link><description>Recent content in Deep-Learning on NEVERMORE</description><generator>Hugo -- gohugo.io</generator><language>zh-cn</language><copyright>ychuang</copyright><lastBuildDate>Sat, 26 Oct 2024 15:34:07 +0800</lastBuildDate><atom:link href="https://hych0317.github.io/hugoweb_auto/categories/deep-learning/index.xml" rel="self" type="application/rss+xml"/><item><title>Transformers</title><link>https://hych0317.github.io/hugoweb_auto/p/transformers/</link><pubDate>Sat, 26 Oct 2024 15:34:07 +0800</pubDate><guid>https://hych0317.github.io/hugoweb_auto/p/transformers/</guid><description>&lt;h2 id="基础部件">Basic components
&lt;/h2>&lt;p>Basic workflow:
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/transformers/steps.png"
loading="lazy"
alt="steps"
>&lt;/p>
&lt;h3 id="tokenizer">tokenizer
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">transformers&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">AutoTokenizer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Load&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tokenizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoTokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;uer/roberta-base-finetuned-dianping-chinese&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">trust_remote_code&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># load from the Hugging Face Hub&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;./my_tokenizer&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># save to a local directory&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tokenizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoTokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;./my_tokenizer&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># load from the local directory&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Tokenize&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tokens&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tokenize&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;你好，欢迎使用！&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">vocab&lt;/span>  &lt;span class="c1"># inspect the vocabulary&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">vocab_size&lt;/span>  &lt;span class="c1"># inspect the vocabulary size&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">ids0&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">convert_tokens_to_ids&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tokens&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># convert tokens to ids&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># tokenizer.convert_ids_to_tokens(ids) maps ids back to tokens; tokenizer.convert_tokens_to_string(tokens) joins tokens back into a sentence&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">ids1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">encode&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;你好，欢迎使用！&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">add_special_tokens&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">False&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># tokenize and convert to ids in one step&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">str1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">decode&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">ids1&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># decode ids back to text&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Pad and truncate to a fixed length so that sequences fit into a batch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">input_ids&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">encode&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;你好，欢迎使用！&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">max_length&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">padding&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;max_length&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">truncation&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
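&lt;p>As a rough illustration of what &lt;code>padding&lt;/code> and &lt;code>truncation&lt;/code> do in the last call above, the sketch below pads or cuts a plain list of ids to a fixed length. It is a simplified stand-in, not the library's implementation: the helper name &lt;code>pad_or_truncate&lt;/code>, the default &lt;code>pad_id=0&lt;/code>, right-side padding, and the sample ids are all assumptions made for this example, and special tokens are ignored. The returned mask marks real tokens with 1 and padding with 0.&lt;/p>
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">def pad_or_truncate(ids, max_length, pad_id=0):
    # Truncation: keep at most max_length ids and drop the overflow.
    kept = list(ids[:max_length])
    # Padding: fill the remaining positions on the right with pad_id.
    n_pad = max_length - len(kept)
    input_ids = kept + [pad_id] * n_pad
    # 1 marks a real token, 0 marks padding.
    attention_mask = [1] * len(kept) + [0] * n_pad
    return input_ids, attention_mask


# Made-up ids, not taken from any real vocabulary.
short_ids, short_mask = pad_or_truncate([11, 12, 13], max_length=5)
long_ids, long_mask = pad_or_truncate([11, 12, 13, 14, 15, 16, 17], max_length=5)
print(short_ids, short_mask)
print(long_ids, long_mask)
&lt;/code>&lt;/pre>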
&lt;/div>&lt;h3 id="model">model
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">transformers&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">AutoModel&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Load&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoModel&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;uer/roberta-base-finetuned-dianping-chinese&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">trust_remote_code&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># load from the Hugging Face Hub&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoModel&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;./model_name&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">output_attentions&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># load from a local directory&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Output&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">inputs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tokenizer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;你好，欢迎使用！&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">return_tensors&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;pt&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">outputs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">**&lt;/span>&lt;span class="n">inputs&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">outputs&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">last_hidden_state&lt;/span>  &lt;span class="c1"># hidden states of the last layer (an attribute, not a method)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Specify a model head&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">transformers&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">AutoModelForSequenceClassification&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">cls_model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoModelForSequenceClassification&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;uer/roberta-base-finetuned-dianping-chinese&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">output_attentions&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">ignore_mismatched_sizes&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># three-class head; ignore_mismatched_sizes is needed when the checkpoint's head has a different number of labels&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># The head applies task-specific processing to the base model's output&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">output_cls&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">cls_model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">**&lt;/span>&lt;span class="n">inputs&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
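&lt;p>The classification head returns one raw score (logit) per label, available as &lt;code>output_cls.logits&lt;/code>. The plain-Python sketch below shows how such scores are commonly turned into probabilities and a predicted label index. The helper &lt;code>softmax&lt;/code> and the sample scores are made up for illustration; with real model output one would typically apply the equivalent tensor operations instead.&lt;/p>
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import math


def softmax(logits):
    # Subtract the maximum first for numerical stability.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up scores for a three-class head (one example, three labels).
logits = [0.2, 2.1, -1.0]
probs = softmax(logits)
# Index of the largest probability, i.e. the predicted label.
pred = max(range(len(probs)), key=probs.__getitem__)
print(probs, pred)
&lt;/code>&lt;/pre>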
&lt;/div>&lt;h3 id="dataset">dataset
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">datasets&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">load_dataset&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">load_from_disk&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Load a dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">datasets&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">load_dataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;dataset_name&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s2">&amp;#34;subtask_name&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># (dataset name, [optional] subtask name if there is one); returns a DatasetDict of splits&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">subset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">load_dataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;dataset_name&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s2">&amp;#34;subtask_name&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">split&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;train[10:100]&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># [optional] split slice; returns a single Dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Inspect&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;train&amp;#34;&lt;/span>&lt;span class="p">][:&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Split&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;train&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train_test_split&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">test_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stratify_by_column&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;label&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># proportional split with a balanced label distribution; the column must be a ClassLabel&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Select and filter&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;train&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">select&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">])&lt;/span>  &lt;span class="c1"># pick the 2nd and 6th examples; the result is still a Dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">filter_dataset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;train&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">filter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="k">lambda&lt;/span> &lt;span class="n">example&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="s2">&amp;#34;中国&amp;#34;&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">example&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;title&amp;#34;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Map&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">processed_datasets&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">datasets&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">map&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">preprocess_function&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batched&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">remove_columns&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;text&amp;#34;&lt;/span>&lt;span class="p">])&lt;/span>  &lt;span class="c1"># user-defined mapping function, [optional] batched processing, drop the text column&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Save to and load from disk&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">processed_datasets&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save_to_disk&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;./processed_data&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">processed_datasets&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">load_from_disk&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;./processed_data&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
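&lt;p>The &lt;code>map&lt;/code> call above refers to a &lt;code>preprocess_function&lt;/code> that is not defined in the snippet. With &lt;code>batched=True&lt;/code>, the function receives a dictionary that maps each column name to a list of values and returns new columns as lists. The sketch below is one possible shape for it, assuming &lt;code>text&lt;/code> and &lt;code>label&lt;/code> columns; the toy character table &lt;code>toy_vocab&lt;/code>, &lt;code>UNK_ID&lt;/code>, and &lt;code>max_length&lt;/code> are hypothetical and stand in for a real tokenizer.&lt;/p>
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python"># Toy character-to-id table standing in for a real tokenizer (hypothetical).
toy_vocab = {'g': 1, 'o': 2, 'd': 3, 'b': 4, 'a': 5}
UNK_ID = 0


def preprocess_function(examples, max_length=4):
    # With batched=True, examples maps each column name to a list of values.
    input_ids = []
    for text in examples['text']:
        ids = [toy_vocab.get(ch, UNK_ID) for ch in text][:max_length]
        input_ids.append(ids)
    # Return new columns as lists with one entry per input example.
    return {'input_ids': input_ids, 'labels': list(examples['label'])}


# Calling the function directly on a hand-built batch shows the expected shapes.
batch = {'text': ['good', 'bad!!'], 'label': [1, 0]}
result = preprocess_function(batch)
print(result)
&lt;/code>&lt;/pre>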
&lt;/div>&lt;h3 id="实例">Example
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">pandas&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">pd&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.utils.data&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">Dataset&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">DataLoader&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">transformers&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">AutoModelForSequenceClassification&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">AutoTokenizer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MyDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Dataset&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">pd&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">read_csv&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;./ChnSentiCorp_htl_all.csv&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropna&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># drop rows with missing values&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__getitem__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">index&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">iloc&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">index&lt;/span>&lt;span class="p">][&lt;/span>&lt;span class="s2">&amp;#34;review&amp;#34;&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">iloc&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">index&lt;/span>&lt;span class="p">][&lt;/span>&lt;span class="s2">&amp;#34;label&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__len__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Load the dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dataset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MyDataset&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Split the dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.utils.data&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">random_split&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">trainset&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">validset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">random_split&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dataset&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">lengths&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mf">0.9&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.1&lt;/span>&lt;span class="p">])&lt;/span>  &lt;span class="c1"># fractional lengths require PyTorch 1.13 or later&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Define the dataloaders&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tokenizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoTokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;hfl/rbt3&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">collate_func&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">batch&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">texts&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">[],&lt;/span> &lt;span class="p">[]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">for&lt;/span> &lt;span class="n">item&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">batch&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">texts&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">append&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">labels&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">append&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">inputs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tokenizer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">texts&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">max_length&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">128&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">padding&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;max_length&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">truncation&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">return_tensors&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;pt&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">inputs&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;labels&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">inputs&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">trainloader&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">trainset&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">collate_fn&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">collate_func&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># shuffle=True reshuffles the data every epoch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">validloader&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">validset&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">64&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">False&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">collate_fn&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">collate_func&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># optimizer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.optim&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">Adam&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoModelForSequenceClassification&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;hfl/rbt3&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">if&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">is_available&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Adam&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">2e-5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Training / evaluation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">evaluate&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">acc_num&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">with&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">inference_mode&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">for&lt;/span> &lt;span class="n">batch&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">validloader&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="k">if&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">is_available&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                &lt;span class="n">batch&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="n">k&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">v&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="n">k&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">v&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">batch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">items&lt;/span>&lt;span class="p">()}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">**&lt;/span>&lt;span class="n">batch&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">pred&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">argmax&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">logits&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dim&lt;/span>&lt;span class="o">=-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">acc_num&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">pred&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">long&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="n">batch&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;labels&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">long&lt;/span>&lt;span class="p">())&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">float&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">acc_num&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">validset&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">train&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">epoch&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">log_step&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">100&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">global_step&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">for&lt;/span> &lt;span class="n">ep&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">epoch&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">for&lt;/span> &lt;span class="n">batch&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">trainloader&lt;/span>&lt;span class="p">:&lt;/span>  &lt;span class="c1"># iterate over the training set&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="k">if&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">is_available&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                &lt;span class="c1"># move every tensor value in the batch to the GPU (the keys are plain strings)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                &lt;span class="n">batch&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="n">k&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">v&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="n">k&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">v&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">batch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">items&lt;/span>&lt;span class="p">()}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zero_grad&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># clear the gradients from the previous step&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">**&lt;/span>&lt;span class="n">batch&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">output&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># backpropagation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">step&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="k">if&lt;/span> &lt;span class="n">global_step&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="n">log_step&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s2">&amp;#34;ep: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">ep&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">, global_step: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">global_step&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">, loss: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">global_step&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="mi">1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">acc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">evaluate&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s2">&amp;#34;ep: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">ep&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">, acc: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">acc&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">train&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">sen&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="s2">&amp;#34;我觉得这家酒店不错，饭很好吃！&amp;#34;&lt;/span>  &lt;span class="c1"># sample input: I think this hotel is nice and the food is delicious!&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">id2_label&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="s2">&amp;#34;Negative review!&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="s2">&amp;#34;Positive review!&amp;#34;&lt;/span>&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">with&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">inference_mode&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">inputs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tokenizer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sen&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">return_tensors&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;pt&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">inputs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="n">k&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">v&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="n">k&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">v&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">inputs&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">items&lt;/span>&lt;span class="p">()}&lt;/span>  &lt;span class="c1"># follow the model&amp;#39;s device so this also runs on CPU&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">logits&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">**&lt;/span>&lt;span class="n">inputs&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">logits&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">pred&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">argmax&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">logits&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dim&lt;/span>&lt;span class="o">=-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s2">&amp;#34;Input: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">sen&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="se">\n&lt;/span>&lt;span class="s2">Prediction: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">id2_label&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">get&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">pred&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">())&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="使用trainer优化实例">Streamlining the example with Trainer
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">transformers&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">AutoTokenizer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">AutoModelForSequenceClassification&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Trainer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">TrainingArguments&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">datasets&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">load_dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Load / split the dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dataset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">load_dataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;csv&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">data_files&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;./ChnSentiCorp_htl_all.csv&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">split&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;train&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dataset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">dataset&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">filter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="k">lambda&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;review&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="ow">is&lt;/span> &lt;span class="ow">not&lt;/span> &lt;span class="kc">None&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">datasets&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">dataset&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train_test_split&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">test_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Preprocess the dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tokenizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoTokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;hfl/rbt3&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">process_function&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">examples&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">tokenized_examples&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tokenizer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">examples&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;review&amp;#34;&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">max_length&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">128&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">truncation&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">tokenized_examples&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;labels&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">examples&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;label&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">tokenized_examples&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tokenized_datasets&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">datasets&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">map&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">process_function&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batched&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">remove_columns&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;train&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">column_names&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Model / evaluation metrics&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoModelForSequenceClassification&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;hfl/rbt3&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">evaluate&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">acc_metric&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">evaluate&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;accuracy&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">f1_metric&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">evaluate&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;f1&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">eval_metric&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">eval_predict&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">predictions&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">eval_predict&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">predictions&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">predictions&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">argmax&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">axis&lt;/span>&lt;span class="o">=-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">acc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">acc_metric&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">compute&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">predictions&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">predictions&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">references&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">f1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">f1_metric&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">compute&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">predictions&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">predictions&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">references&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">acc&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">update&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">f1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">acc&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Create the training arguments&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">train_args&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">TrainingArguments&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output_dir&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;./checkpoints&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1"># output directory&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                               &lt;span class="n">per_device_train_batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">64&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1"># batch size for training&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                               &lt;span class="n">per_device_eval_batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">128&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1"># batch size for evaluation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                               &lt;span class="n">logging_steps&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1"># log every 10 steps&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                               &lt;span class="n">evaluation_strategy&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;epoch&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1"># evaluate once per epoch (newer transformers releases name this eval_strategy)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                               &lt;span class="n">save_strategy&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;epoch&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1"># save a checkpoint once per epoch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                               &lt;span class="n">save_total_limit&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1"># keep at most 3 checkpoints&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                               &lt;span class="n">learning_rate&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">2e-5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1"># learning rate&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                               &lt;span class="n">weight_decay&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.01&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1"># weight decay&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                               &lt;span class="n">metric_for_best_model&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;f1&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1"># metric used to pick the best model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                               &lt;span class="n">load_best_model_at_end&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># load the best model when training finishes&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Create the trainer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">transformers&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">DataCollatorWithPadding&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">trainer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Trainer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                  &lt;span class="n">args&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">train_args&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                  &lt;span class="n">train_dataset&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">tokenized_datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;train&amp;#34;&lt;/span>&lt;span class="p">],&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                  &lt;span class="n">eval_dataset&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">tokenized_datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;test&amp;#34;&lt;/span>&lt;span class="p">],&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                  &lt;span class="n">data_collator&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">DataCollatorWithPadding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tokenizer&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">tokenizer&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                  &lt;span class="n">compute_metrics&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">eval_metric&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Train / evaluate / predict&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">trainer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">trainer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">evaluate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tokenized_datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;test&amp;#34;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">trainer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">predict&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tokenized_datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;test&amp;#34;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
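&lt;/div>&lt;h2 id="nlp任务实操">Hands-on NLP tasks
&lt;/h2>&lt;h3 id="命名实体识别ner">Named entity recognition (NER)
&lt;/h3>&lt;p>NER identifies entities in text, such as person names, place names, and organization names.&lt;br>
An NER task usually has two parts:&lt;/p>
&lt;ul>
&lt;li>Entity detection: locate the entities in the text and assign each one a label.&lt;/li>
&lt;li>Entity classification: classify each detected entity into a type, such as person, location, or organization.&lt;/li>
&lt;/ul>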
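&lt;p>As a rough illustration of how the two parts fit together, the sketch below decodes per-token tags into labeled entity spans. It assumes the BIO tagging scheme (B- opens an entity, I- continues it, O is outside any entity), which the notes above do not specify; the function name bio_to_spans and the sample sentence are hypothetical.&lt;/p>

```python
def bio_to_spans(tokens, tags):
    """Group tokens into (entity_text, label) pairs from BIO tags (assumed scheme)."""
    spans, current, label = [], [], None
    for token, tag in zip(tokens, tags):
        if tag.startswith("B-"):
            # A new entity starts: flush the one being built, if any.
            if current:
                spans.append((" ".join(current), label))
            current, label = [token], tag[2:]
        elif tag.startswith("I-") and current and tag[2:] == label:
            # Continuation of the current entity.
            current.append(token)
        else:
            # "O" or an inconsistent I- tag: close the current entity and skip the token.
            if current:
                spans.append((" ".join(current), label))
            current, label = [], None
    if current:
        spans.append((" ".join(current), label))
    return spans


tokens = ["Ada", "Lovelace", "visited", "London"]
tags = ["B-PER", "I-PER", "O", "B-LOC"]
entities = bio_to_spans(tokens, tags)
print(entities)
```

&lt;p>Detection corresponds to finding where each span starts and ends; classification corresponds to the label attached to the span.&lt;/p>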
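&lt;h2 id="peft微调">PEFT fine-tuning
&lt;/h2>&lt;p>After creating the model, build a tuning config, then wrap the model with model = get_peft_model(model, config).&lt;/p>
&lt;blockquote>
&lt;p>For a survey of common parameter-efficient fine-tuning methods, see arXiv:2303.15647.&lt;/p>
&lt;/blockquote>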
&lt;h3 id="prompt-tuning">Prompt tuning
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">peft&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">PromptTuningConfig&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">get_peft_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">TaskType&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">PromptTuningInit&lt;/span>
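&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Text-initialized prompt (often called a hard prompt); PromptTuningInit.RANDOM gives a randomly initialized soft prompt&lt;/span>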
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">PromptTuningConfig&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">task_type&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">TaskType&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CAUSAL_LM&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">prompt_tuning_init&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">PromptTuningInit&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">TEXT&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">prompt_tuning_init_text&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;下面是一段人与机器人的对话。&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">num_virtual_tokens&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tokenizer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;下面是一段人与机器人的对话。&amp;#34;&lt;/span>&lt;span class="p">)[&lt;/span>&lt;span class="s2">&amp;#34;input_ids&amp;#34;&lt;/span>&lt;span class="p">]),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">tokenizer_name_or_path&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;Langboat/bloom-1b4-zh&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">get_peft_model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">config&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Train...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Load the trained model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">peft&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">PeftModel&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoModelForCausalLM&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;Langboat/bloom-1b4-zh&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># base model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">peft_model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">PeftModel&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model_id&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;./output/checkpoint-500/&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="p-tuningprefix-tuning">P-tuning/Prefix tuning
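&lt;/h3>&lt;p>P-tuning prepends the prompt as a prefix at the input embedding layer, whereas prefix tuning prepends key/value vectors as a prefix in every layer of the model, not only at the input layer.&lt;br>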
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/transformers/prefix.png"
loading="lazy"
alt="prefix tuning"
>&lt;/p>
&lt;p>How it works (similar in spirit to a KV cache):
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/transformers/prefix_kvcache.png"
loading="lazy"
alt="prefix tuning"
>
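With the extended K and V matrices, $Q \in \mathbb{R}^{m \times n}$, $K \in \mathbb{R}^{(m+x) \times n}$ and $V \in \mathbb{R}^{(m+x) \times n}$, the product $QK^T$ is an $m \times (m+x)$ matrix, and multiplying it by $V$ gives an $m \times n$ matrix, the same shape as the result without the prefix.&lt;/p>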
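&lt;p>The shape argument can be checked with a small sketch in plain Python. The sizes m = 3, n = 4, x = 2 are arbitrary illustrative values, the matrices are filled with ones, and the softmax and scaling of real attention are left out because they do not change any shape.&lt;/p>

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def shape(a):
    return (len(a), len(a[0]))


m, n, x = 3, 4, 2  # m input tokens, hidden size n, x prefix positions (illustrative)

Q = [[1.0] * n for _ in range(m)]      # (m, n): queries come from the input tokens only
K = [[1.0] * n for _ in range(m + x)]  # (m + x, n): x prefix keys plus m token keys
V = [[1.0] * n for _ in range(m + x)]  # (m + x, n): x prefix values plus m token values

scores = matmul(Q, transpose(K))       # (m, m + x)
output = matmul(scores, V)             # (m, n), same as without the prefix

print(shape(scores), shape(output))
```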
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">peft&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">PrefixTuningConfig&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">get_peft_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">TaskType&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">PrefixTuningConfig&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">task_type&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">TaskType&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CAUSAL_LM&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_virtual_tokens&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">prefix_projection&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># prefix_projection defaults to False, which corresponds to P-Tuning v2; True corresponds to Prefix Tuning&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># The rest of the workflow is the same&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
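&lt;/div>&lt;h3 id="lora">LoRA
&lt;/h3>&lt;p>LoRA keeps the original weights frozen and represents the weight update as the product of two low-rank matrices. Only the low-rank matrices are optimized during training; at the end their product is added to the original weights to give the fine-tuned result.&lt;/p>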
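&lt;p>A minimal numeric sketch of that idea, in plain Python rather than with the peft library. The matrix values, the rank r = 1 and the alpha / r scaling convention are assumptions chosen for illustration, and the helper names are hypothetical.&lt;/p>

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]


def add_scaled(w, delta, scale):
    """Return w + scale * delta, element by element."""
    return [[wv + scale * dv for wv, dv in zip(wr, dr)] for wr, dr in zip(w, delta)]


W = [[1.0, 2.0], [3.0, 4.0]]  # frozen pretrained weight, shape (d, k) = (2, 2)
B = [[0.5], [-1.0]]           # trainable low-rank factor, shape (d, r) with r = 1
A = [[2.0, 0.0]]              # trainable low-rank factor, shape (r, k)
alpha, r = 2, 1               # assumed scaling hyperparameters

delta_W = matmul(B, A)                          # low-rank update, shape (d, k)
W_merged = add_scaled(W, delta_W, alpha / r)    # what merging the adapter produces

# Parameter count for a larger, purely illustrative layer
d, k, rank = 1024, 1024, 8
full_params = d * k
lora_params = rank * (d + k)
print(W_merged, full_params, lora_params)
```

&lt;p>For the 1024 x 1024 example, the two factors hold about 1.6% as many values as the full matrix, which is where the memory saving during training comes from.&lt;/p>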
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">peft&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">LoraConfig&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">TaskType&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">get_peft_model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Print the parameter names to choose the weight layers for target_modules; the argument accepts a list such as:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># [&amp;#34;word_embeddings&amp;#34;, &amp;#34;encoder.layer.0.attention.self.query&amp;#34;, &amp;#34;encoder.layer.0.attention.self.key&amp;#34;, &amp;#34;encoder.layer.0.attention.self.value&amp;#34;]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># A regular expression also works, as in the config below&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">name&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">parameter&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">named_parameters&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">name&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">LoraConfig&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">task_type&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">TaskType&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CAUSAL_LM&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">target_modules&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="sa">r&lt;/span>&lt;span class="s2">&amp;#34;.*\.1.*query_key_value&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">modules_to_save&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;word_embeddings&amp;#34;&lt;/span>&lt;span class="p">])&lt;/span>&lt;span class="c1"># modules_to_save: other weight layers that should also be trained&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">get_peft_model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">config&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Train...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Load the trained model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">peft&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">PeftModel&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoModelForCausalLM&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;Langboat/bloom-1b4-zh&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tokenizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoTokenizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;Langboat/bloom-1b4-zh&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">peft_model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">PeftModel&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model_id&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;./output/checkpoint-500/&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Merge the adapter into the base model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># peft_model and merge_model give the same effective weights; peft_model keeps the pretrained weights and the LoRA weights separate and applies the LoRA weights at inference time, while merge_model is a single standalone model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">merge_model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">peft_model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">merge_and_unload&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">merge_model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;./output/merge_model&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># save the merged model&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="ia3">IA3
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Only the call pattern is recorded here&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">peft&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">IA3Config&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">TaskType&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">get_peft_model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">IA3Config&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">task_type&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">TaskType&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CAUSAL_LM&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="使用不同适配器">Using different adapters
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;span class="lnt">25
&lt;/span>&lt;span class="lnt">26
&lt;/span>&lt;span class="lnt">27
&lt;/span>&lt;span class="lnt">28
&lt;/span>&lt;span class="lnt">29
&lt;/span>&lt;span class="lnt">30
&lt;/span>&lt;span class="lnt">31
&lt;/span>&lt;span class="lnt">32
&lt;/span>&lt;span class="lnt">33
&lt;/span>&lt;span class="lnt">34
&lt;/span>&lt;span class="lnt">35
&lt;/span>&lt;span class="lnt">36
&lt;/span>&lt;span class="lnt">37
&lt;/span>&lt;span class="lnt">38
&lt;/span>&lt;span class="lnt">39
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">peft&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">LoraConfig&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">get_peft_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">PeftModel&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Apply LoRA to layer 0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">LoraConfig&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">target_modules&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;0&amp;#34;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">get_peft_model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">config1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model1&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;./loraA&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Apply LoRA to layer 2&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">LoraConfig&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">target_modules&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;2&amp;#34;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">get_peft_model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">config2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model2&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;./loraB&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># model2 now shows LoRA on both layer 0 and layer 2, because net1 keeps the changes made for loraA&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># However, on checking, loraB stores only the layer-2 weight update: model2 takes net1+loraA as input and yields net1+loraA+loraB, so loraB records only the layer-2 weights&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">)&lt;/span>&lt;span class="c1"># net1 above was modified when it was used, so redefine the original network&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Build a PeftModel from the original network and the saved loraA weights&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model3&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">PeftModel&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model_id&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;./loraA/&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">adapter_name&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;loraA&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># the model is now net1+loraA (adapter weights for layer 0)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model3&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">active_adapter&lt;/span>&lt;span class="c1"># shows the currently active adapter, loraA&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Switch to the loraB weights&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model3&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load_adapter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;./loraB/&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">adapter_name&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;loraB&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># loads loraB: the model now holds net1+loraA+loraB, but the active structure is still net1+loraA (not switched yet)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model3&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">set_adapter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;loraB&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># switches to loraB: loraA is deactivated and the active structure becomes net1+loraB&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model3&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">active_adapter&lt;/span>&lt;span class="c1"># shows the currently active adapter, loraB&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">with&lt;/span> &lt;span class="n">model3&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">disable_adapter&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="o">...&lt;/span>  &lt;span class="c1"># code that should run without adapters goes here; adapters are disabled only inside the with block&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
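&lt;/div>&lt;h2 id="低精度训练">Low-precision training
&lt;/h2>&lt;p>The default is single precision (fp32), at 4 bytes per parameter. Half precision means fp16 (bf16 is the better choice), at 2 bytes per parameter.&lt;/p>
&lt;h3 id="半精度训练实例">Half-precision training example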
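&lt;p>The bytes-per-parameter figures translate directly into a memory estimate for the weights. The sketch below is back-of-the-envelope arithmetic only: the 7-billion parameter count is an illustrative assumption, GB is taken as 10^9 bytes, and gradients, optimizer state and activations are not included.&lt;/p>

```python
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2}


def weight_memory_gb(num_params, dtype):
    """Memory needed to hold the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


num_params = 7_000_000_000  # illustrative model size
for dtype in ("fp32", "fp16", "bf16"):
    print(dtype, weight_memory_gb(num_params, dtype), "GB")
```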
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoModelForCausalLM&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;&amp;lt;model name&amp;gt;&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">low_cpu_mem_usage&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">torch_dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bfloat16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">device_map&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;auto&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># half-precision training&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Setting the dtype at load time is the recommended way&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">half&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># After fine-tuning, this also converts the tuned parameters to half precision (note that .half() casts to fp16)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
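&lt;/div>&lt;h3 id="量化">Quantization
&lt;/h3>&lt;p>Quantization lowers GPU memory usage, but training and inference become slower.&lt;/p>
&lt;p>INT8 quantization maps a floating-point value $x_f$ into the range [-128, 127] through a scaling factor scale, so that it can be stored in 8 bits:
$$x_q = \mathrm{Clip}(\mathrm{Round}(x_f \cdot scale))$$
where scale = 127 / (largest absolute value among the floats) and Round rounds to the nearest integer.&lt;br>
Outliers (values that differ greatly from the rest) cause a lot of information to be lost, because one large value shrinks scale and the remaining values collapse onto only a few integer levels. Clip keeps every value, outliers included, within [-128, 127].&lt;/p>
&lt;p>Dequantization reverses the mapping:
$$x_f = x_q / scale$$&lt;/p>
&lt;p>Mixed-precision quantization is therefore an option:&lt;br>
the few dimensions that contain emergent features are separated from the matrix and multiplied in high precision, while the remaining values, which are close in magnitude, are quantized.&lt;/p>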
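&lt;p>The formulas above can be sketched in plain Python to show the effect of an outlier. The input values are made up for illustration and the function names are hypothetical; Python's built-in round resolves exact ties to the nearest even integer, which does not matter for these inputs.&lt;/p>

```python
def quantize_absmax(values):
    """Absmax INT8 quantization: scale = 127 / max(|x_f|), then round and clip."""
    scale = 127 / max(abs(v) for v in values)
    quantized = [max(-128, min(127, round(v * scale))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Inverse mapping: x_f = x_q / scale."""
    return [q / scale for q in quantized]


# Values of similar magnitude use the whole integer range.
q_plain, scale_plain = quantize_absmax([0.1, -0.2, 0.3, 0.25])

# One outlier (40.0) shrinks the scale, so the small values collapse to 0, -1, 1.
q_outlier, scale_outlier = quantize_absmax([0.1, -0.2, 0.3, 40.0])

print(q_plain, dequantize(q_plain, scale_plain))
print(q_outlier, dequantize(q_outlier, scale_outlier))
```

&lt;p>In the second case the three small values can no longer be told apart in any useful way after dequantization, which is the information loss that mixed-precision quantization avoids by handling the outlier dimensions separately.&lt;/p>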
&lt;h3 id="8bit4bit量化与qlora模型训练">8-bit and 4-bit quantization and QLoRA model training
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoModelForCausalLM&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;D:/Pretrained_models/modelscope/Llama-2-7b-ms&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">low_cpu_mem_usage&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">torch_dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bfloat16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">device_map&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;auto&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">load_in_8bit&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">AutoModelForCausalLM&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_pretrained&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;D:/Pretrained_models/modelscope/Llama-2-13b-ms&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">low_cpu_mem_usage&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">torch_dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bfloat16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">device_map&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;auto&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">load_in_4bit&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">bnb_4bit_compute_dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bfloat16&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">bnb_4bit_quant_type&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;nf4&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">bnb_4bit_use_double_quant&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># enable NF4 quantization and double quantization&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
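&lt;/div>&lt;h2 id="分布式训练">Distributed Training
&lt;/h2>&lt;h3 id="各类并行">Types of Parallelism
&lt;/h3>&lt;p>Data parallel: every GPU holds a full copy of the model and trains on a different shard of the data.&lt;br>
Pipeline parallel: each GPU holds a different group of the model's layers.&lt;br>
Tensor parallel: the parameters of a single layer are split across several GPUs.&lt;/p>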
&lt;p>3D parallelism:
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/transformers/3Dpara.png"
loading="lazy"
alt="3Dpara"
>
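In the figure: 2 (data parallel) * 4 (pipeline parallel | horizontal arrows, different layers) * 4 (tensor parallel | vertical arrows, different parameters of the same layer) = 32 GPUs&lt;br>
Explanation: the model has 32 layers, and every 8 layers form one pipeline stage, which gives 4 stages; each stage is split across 4 tensor-parallel GPUs, so one model replica uses 4 * 4 = 16 GPUs; with 2 data-parallel replicas the total is 32 GPUs.&lt;/p>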
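&lt;p>The arithmetic above can be sketched in a few lines of plain Python. The sizes are the ones read off the figure; the helper gpu_rank and its rank ordering are hypothetical and only illustrate that every (data, pipeline, tensor) coordinate maps to exactly one GPU. Real frameworks choose their own layout.&lt;/p>

```python
# Sizes taken from the figure: 32 layers, 4 pipeline stages,
# 4-way tensor parallelism, 2 data-parallel replicas.
num_layers = 32
pipeline_stages = 4
tensor_parallel = 4
data_parallel = 2

layers_per_stage = num_layers // pipeline_stages      # layers held by one stage
gpus_per_replica = pipeline_stages * tensor_parallel  # GPUs for one model copy
total_gpus = gpus_per_replica * data_parallel         # GPUs for all replicas


def gpu_rank(dp, pp, tp):
    """Hypothetical flat index for (data, pipeline, tensor) coordinates."""
    return (dp * pipeline_stages + pp) * tensor_parallel + tp


all_ranks = {
    gpu_rank(dp, pp, tp)
    for dp in range(data_parallel)
    for pp in range(pipeline_stages)
    for tp in range(tensor_parallel)
}
print(layers_per_stage, gpus_per_replica, total_gpus, len(all_ranks))
```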
&lt;h3 id="distributed-dataparallel">Distributed DataParallel
&lt;/h3>&lt;p>&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/transformers/datapara.png"
loading="lazy"
alt="datapara"
>&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># nn.DataParallel (single process, not DistributedDataParallel): use GPUs 0, 1 and 2; if device_ids is omitted or None, all visible GPUs are used&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">DataParallel&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">device_ids&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># During training, call mean() on the loss: when the model returns a loss, one value per GPU is gathered, and backward() needs a scalar&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="accelerater">Accelerate
&lt;/h3></description></item><item><title>Machine Learning</title><link>https://hych0317.github.io/hugoweb_auto/p/machine-learning/</link><pubDate>Mon, 21 Oct 2024 15:34:07 +0800</pubDate><guid>https://hych0317.github.io/hugoweb_auto/p/machine-learning/</guid><description>&lt;h2 id="基础概念">Basic Concepts
&lt;/h2>&lt;h3 id="张量tensor">Tensor
&lt;/h3>&lt;p>A tensor is a multi-dimensional array that can represent scalars, vectors, matrices and higher-order arrays. Its individual elements are scalars: a 0-dimensional tensor is a single scalar, a 1-dimensional tensor is a vector, and a 2-dimensional tensor is a matrix.&lt;/p>
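&lt;h3 id="损失函数">Loss Function
&lt;/h3>&lt;p>A loss function measures the gap between the model's predictions and the true values, and so reflects how accurate the predictions are. The choice of loss function strongly affects training, optimization and generalization. Common loss functions include:&lt;/p>
&lt;h4 id="残差平方和residual-sum-of-squares-rss">Residual Sum of Squares (RSS)
&lt;/h4>&lt;p>Formula: $L(y, \hat{y}) = \sum_{i=1}^n (y_i - \hat{y}_i)^2$&lt;/p>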
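&lt;p>A minimal sketch of the RSS formula in plain Python. The function name and the sample numbers are made up for illustration. Dividing RSS by the number of samples gives the mean squared error.&lt;/p>

```python
def residual_sum_of_squares(y_true, y_pred):
    """RSS: sum over i of (y_i - y_hat_i) squared."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum((y - y_hat) ** 2 for y, y_hat in zip(y_true, y_pred))


# Made-up targets and predictions, for illustration only.
y_true = [3.0, -0.5, 2.0, 7.0]
y_pred = [2.5, 0.0, 2.0, 8.0]

rss = residual_sum_of_squares(y_true, y_pred)
mse = rss / len(y_true)  # mean squared error is RSS averaged over the samples
print(rss, mse)
```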
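&lt;h3 id="梯度下降法">Gradient Descent
&lt;/h3>&lt;p>Gradient descent is an optimization algorithm that iteratively updates the model parameters so that the value of the loss function gradually decreases. The basic idea is to move the parameters a small step along the negative gradient of the loss: $\theta \leftarrow \theta - \eta \nabla_\theta L(\theta)$, where $\eta$ is the learning rate.&lt;/p>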
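&lt;p>The update rule can be tried on a toy problem. The sketch below assumes a one-parameter quadratic loss chosen only for illustration; the function name, learning rate and step count are arbitrary.&lt;/p>

```python
def gradient_descent(grad, theta0, lr=0.1, steps=50):
    """Repeatedly step against the gradient: theta = theta - lr * grad(theta)."""
    theta = theta0
    history = [theta]
    for _ in range(steps):
        theta = theta - lr * grad(theta)
        history.append(theta)
    return theta, history


def loss(theta):
    """Toy loss L(theta) = (theta - 3)^2, minimized at theta = 3."""
    return (theta - 3.0) ** 2


def grad(theta):
    """Derivative of the toy loss: 2 * (theta - 3)."""
    return 2.0 * (theta - 3.0)


theta, history = gradient_descent(grad, theta0=0.0)
print(theta, loss(theta))
```

&lt;p>With a learning rate that is too large (above 1.0 for this particular loss) the iterates overshoot and diverge, which is why the step size has to be tuned.&lt;/p>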
&lt;p>Example: computing the gradient of softmax:
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/Machine_Learning/softmax_grad.png"
loading="lazy"
alt="softmax"
>&lt;/p>
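&lt;h3 id="反向传播">Backpropagation
&lt;/h3>&lt;p>Backpropagation is the algorithm that computes the gradient of the loss with respect to every model parameter. Starting from the output layer, it applies the chain rule layer by layer back towards the input layer; the optimizer (for example gradient descent) then uses these gradients to update the parameters.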
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/Machine_Learning/back1.png"
loading="lazy"
alt="pic1"
>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/Machine_Learning/back2.png"
loading="lazy"
alt="pic2"
>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/Machine_Learning/back3.png"
loading="lazy"
alt="pic3"
>&lt;/p>
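&lt;p>A minimal hand-written sketch of the chain rule, assuming a toy network with one hidden ReLU unit and a squared-error loss (the network, names and numbers are illustrative, not taken from the figures). The analytic gradients are checked against a finite-difference estimate.&lt;/p>

```python
def forward_backward(x, y, w1, w2):
    """Toy network: h = relu(w1 * x), y_hat = w2 * h, loss = (y_hat - y)^2."""
    # Forward pass, keeping the intermediates needed by the backward pass.
    z = w1 * x
    h = max(z, 0.0)
    y_hat = w2 * h
    loss = (y_hat - y) ** 2
    # Backward pass: chain rule from the loss back towards the input.
    d_y_hat = 2.0 * (y_hat - y)          # dL/dy_hat
    d_w2 = d_y_hat * h                   # dL/dw2
    d_h = d_y_hat * w2                   # dL/dh
    d_z = d_h * (1.0 if z > 0 else 0.0)  # dL/dz through the ReLU
    d_w1 = d_z * x                       # dL/dw1
    return loss, d_w1, d_w2


def numerical_grad(f, v, eps=1e-6):
    """Central finite difference, used only to check the analytic gradient."""
    return (f(v + eps) - f(v - eps)) / (2 * eps)


x, y, w1, w2 = 1.5, 2.0, 0.8, -0.5
loss, d_w1, d_w2 = forward_backward(x, y, w1, w2)
check_w1 = numerical_grad(lambda w: forward_backward(x, y, w, w2)[0], w1)
check_w2 = numerical_grad(lambda w: forward_backward(x, y, w1, w)[0], w2)
print(d_w1, check_w1)
print(d_w2, check_w2)
```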
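&lt;h2 id="神经网络">Neural Networks
&lt;/h2>&lt;h3 id="cnn">CNN
&lt;/h3>&lt;p>Network structure:&lt;br>
convolutional layer -&amp;gt; activation function -&amp;gt; pooling layer -&amp;gt; fully connected layer -&amp;gt; activation function -&amp;gt; output layer&lt;br>
Convolution kernels extract image features, the activation function applies a non-linear transformation to them, the pooling layer downsamples the feature maps, and the fully connected layers perform the classification.&lt;/p>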
&lt;h3 id="rnn">RNN
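&lt;/h3>&lt;p>The output of the activation function is fed back into the network as an input at the next time step, so the network can remember earlier inputs and make a better prediction for the current one. However, the gradient is multiplied by the weight on this recurrent path at every time step, which leads to vanishing gradients (|w|&amp;lt;1) or exploding gradients (|w|&amp;gt;1) on long sequences. LSTM was proposed to address this.&lt;/p>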
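&lt;p>The effect of the recurrent weight can be illustrated with a deliberately simplified sketch: a scalar linear recurrence in which the activation derivative is ignored, so the gradient is simply scaled by w at each time step. The function name is hypothetical.&lt;/p>

```python
def gradient_scale(w, steps):
    """Factor applied to a gradient after `steps` time steps of the scalar
    linear recurrence h_t = w * h_(t-1), ignoring the activation derivative."""
    return w ** steps


for w in (0.9, 1.0, 1.1):
    # Below 1 the factor shrinks towards zero, above 1 it blows up.
    print(w, [gradient_scale(w, t) for t in (1, 10, 50, 100)])
```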
&lt;h3 id="lstm">LSTM
&lt;/h3>&lt;p>Forget gate, input gate and output gate:
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/Machine_Learning/forgetgate.png"
loading="lazy"
alt="LSTM1"
>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/Machine_Learning/inputgate.png"
loading="lazy"
alt="LSTM2"
>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/Machine_Learning/outputgate.png"
loading="lazy"
alt="LSTM3"
>&lt;/p>
&lt;h3 id="gan">GAN
&lt;/h3>&lt;h3 id="transformer">Transformer
&lt;/h3>&lt;h4 id="embedding">Embedding
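&lt;/h4>&lt;p>Word embedding is the process of converting words into vectors of a fixed dimension. Embeddings can improve performance on tasks such as text classification, text matching and text clustering. Common sources of embeddings include Word2Vec, GloVe and contextual models such as BERT; a bag-of-words model, by contrast, is a sparse count-based representation rather than a learned embedding.&lt;/p>
&lt;p>Positional encoding: a position-specific sequence of values is generated for each input position and added to the token embedding, so that word order is preserved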
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/Machine_Learning/pos_encoding.png"
loading="lazy"
alt="position encoding"
>&lt;/p>
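&lt;p>As one concrete possibility, the sketch below assumes the sinusoidal scheme from the original Transformer paper; the notes above do not say which scheme the figure uses, so treat this as an illustration only. The function name is hypothetical.&lt;/p>

```python
import math


def positional_encoding(num_positions, d_model):
    """Sinusoidal table: even columns hold sin(pos / 10000^(i / d_model)),
    odd columns hold the cosine of the same angle (i is the even column index)."""
    if d_model % 2 != 0:
        raise ValueError("d_model must be even in this sketch")
    table = []
    for pos in range(num_positions):
        row = []
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            row.append(math.sin(angle))
            row.append(math.cos(angle))
        table.append(row)
    return table


pe = positional_encoding(num_positions=4, d_model=6)
for row in pe:
    print([round(value, 3) for value in row])
```

&lt;p>Each row is the vector for one position; it has the same width as the token embedding so the two can be added element-wise.&lt;/p>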
&lt;h4 id="self-attention">Self Attention
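&lt;/h4>&lt;p>Self-attention lets the model gather global information about the input sequence through the attention mechanism. This gives every input its context and links the inputs to one another.&lt;br>
With self-attention, each token computes its similarity to every other token (Query and Key), which tells the model, for example, which word the pronoun it refers to.&lt;br>
Note that the Q/K/V weight matrices are shared by all tokens.&lt;br>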
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/Machine_Learning/selfattention.png"
loading="lazy"
alt="self attention"
>&lt;/p>
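&lt;p>The Query-Key similarity scores are passed through a softmax to obtain a weight for every token; the Value vectors are then summed with these weights to give the final embedding vector of each token.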
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/Machine_Learning/transformer1.png"
loading="lazy"
alt="self attention2"
>&lt;/p>
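&lt;p>The steps above can be sketched in plain Python. The toy embeddings and the identity projection matrices are made up for illustration, and dividing the scores by the square root of the key dimension follows the original Transformer paper rather than anything stated in these notes.&lt;/p>

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(vec, matrix):
    """Row vector times matrix (matrix given as a list of rows)."""
    return [sum(v * row[j] for v, row in zip(vec, matrix))
            for j in range(len(matrix[0]))]


def self_attention(tokens, w_q, w_k, w_v):
    """Scaled dot-product self-attention; the same w_q, w_k and w_v
    are applied to every token."""
    queries = [matvec(t, w_q) for t in tokens]
    keys = [matvec(t, w_k) for t in tokens]
    values = [matvec(t, w_v) for t in tokens]
    scale = math.sqrt(len(keys[0]))
    outputs, all_weights = [], []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
        weights = softmax(scores)  # one weight per token
        out = [sum(w * v[j] for w, v in zip(weights, values))
               for j in range(len(values[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights


# Made-up 2-dimensional embeddings for three tokens, identity projections.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
outputs, weights = self_attention(tokens, identity, identity, identity)
for row in weights:
    print([round(w, 3) for w in row])
```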
&lt;h4 id="encoder-decoder-attention">Encoder-decoder Attention
&lt;/h4></description></item><item><title>Pytorch</title><link>https://hych0317.github.io/hugoweb_auto/p/pytorch/</link><pubDate>Mon, 21 Oct 2024 15:34:07 +0800</pubDate><guid>https://hych0317.github.io/hugoweb_auto/p/pytorch/</guid><description>&lt;h2 id="常用概念">Common Concepts
&lt;/h2>&lt;h3 id="设备相关">Devices
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">device&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;cuda&amp;#34;&lt;/span> &lt;span class="k">if&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">is_available&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="k">else&lt;/span> &lt;span class="s2">&amp;#34;cpu&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># create the device&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">a&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ones&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">b&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">a&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">clone&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">c&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">a&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">detach&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="c1"># use only the current values as a constant: c is cut off from the autograd graph, so no gradient flows back to a&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">d&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">a&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># &amp;#34;cuda&amp;#34; / &amp;#34;cpu&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="流程简述">Workflow Overview
&lt;/h3>&lt;p>See Part 3 for details on specific networks.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;span class="lnt">25
&lt;/span>&lt;span class="lnt">26
&lt;/span>&lt;span class="lnt">27
&lt;/span>&lt;span class="lnt">28
&lt;/span>&lt;span class="lnt">29
&lt;/span>&lt;span class="lnt">30
&lt;/span>&lt;span class="lnt">31
&lt;/span>&lt;span class="lnt">32
&lt;/span>&lt;span class="lnt">33
&lt;/span>&lt;span class="lnt">34
&lt;/span>&lt;span class="lnt">35
&lt;/span>&lt;span class="lnt">36
&lt;/span>&lt;span class="lnt">37
&lt;/span>&lt;span class="lnt">38
&lt;/span>&lt;span class="lnt">39
&lt;/span>&lt;span class="lnt">40
&lt;/span>&lt;span class="lnt">41
&lt;/span>&lt;span class="lnt">42
&lt;/span>&lt;span class="lnt">43
&lt;/span>&lt;span class="lnt">44
&lt;/span>&lt;span class="lnt">45
&lt;/span>&lt;span class="lnt">46
&lt;/span>&lt;span class="lnt">47
&lt;/span>&lt;span class="lnt">48
&lt;/span>&lt;span class="lnt">49
&lt;/span>&lt;span class="lnt">50
&lt;/span>&lt;span class="lnt">51
&lt;/span>&lt;span class="lnt">52
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 1. Network definition: a simple fully connected neural network&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.optim&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">optim&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">SimpleNN&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">SimpleNN&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># input layer to hidden layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># hidden layer to output layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># ReLU activation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 2. Create the network instance&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">SimpleNN&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 3. Define the loss function and the optimizer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">criterion&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">MSELoss&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># mean squared error loss&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Adam&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.001&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># Adam optimizer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 4. Assume training data X and Y&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">X&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 10 samples, 2 features&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Y&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 10 target values&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 5. Training loop&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">epoch&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">100&lt;/span>&lt;span class="p">):&lt;/span> &lt;span class="c1"># train for 100 epochs&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zero_grad&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># clear the gradients from the previous step&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># forward pass&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># output can be passed through an activation function that suits the task, e.g.:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># import torch.nn.functional as F&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># # ReLU activation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># output = F.relu(input_tensor)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># # Sigmoid activation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># output = torch.sigmoid(input_tensor)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># # Tanh activation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># output = torch.tanh(input_tensor)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">criterion&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Y&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># compute the loss&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># backward pass&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">step&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># update the parameters&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># print the loss every 10 epochs&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">epoch&lt;/span>&lt;span class="o">+&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="mi">10&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s1">&amp;#39;Epoch [&lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">epoch&lt;/span>&lt;span class="o">+&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">/100], Loss: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s1">.4f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">X_test&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Y_test&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 6. Evaluation (random placeholder test data, for illustration only)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># switch the model to evaluation mode&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">with&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">no_grad&lt;/span>&lt;span class="p">():&lt;/span> &lt;span class="c1"># disable gradient tracking during evaluation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">X_test&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">criterion&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Y_test&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s1">&amp;#39;Test Loss: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s1">.4f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
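&lt;/div>&lt;h2 id="张量操作">Tensor Operations
&lt;/h2>&lt;h3 id="张量基本操作">Basic Tensor Operations
&lt;/h3>&lt;h4 id="基本属性">Basic Attributes
&lt;/h4>&lt;p>Tensor:&lt;br>
Attributes: number of dimensions, shape, data type&lt;br>
A 0-dimensional tensor is a single number; a 1-dimensional tensor is a one-dimensional array&lt;br>
The shape gives the size of each dimension, e.g. (3,4) means three rows and four columns&lt;/p>
&lt;p>Others include the number of dimensions .dim(), whether gradients are tracked .requires_grad, and the total number of elements .numel().&lt;/p>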
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;span class="lnt">25
&lt;/span>&lt;span class="lnt">26
&lt;/span>&lt;span class="lnt">27
&lt;/span>&lt;span class="lnt">28
&lt;/span>&lt;span class="lnt">29
&lt;/span>&lt;span class="lnt">30
&lt;/span>&lt;span class="lnt">31
&lt;/span>&lt;span class="lnt">32
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 1. Create tensors&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([[&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">],[&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">6&lt;/span>&lt;span class="p">]])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x0&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zeros&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ones&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">xr&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># rand: uniform on [0, 1); randn: standard normal distribution&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">device&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;cuda&amp;#34;&lt;/span> &lt;span class="k">if&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">is_available&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="k">else&lt;/span> &lt;span class="s2">&amp;#34;cpu&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">xd&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">device&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">X&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">arange&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">12&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">float32&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">reshape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="c1"># set the element count, dtype and shape&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">tensor_3d&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">stack&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="n">xd&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">xd&lt;/span>&lt;span class="o">+&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">xd&lt;/span>&lt;span class="o">+&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">])&lt;/span>&lt;span class="c1"># a 3-D tensor is a stack of 2-D tensors; build it with stack&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 2. Change the shape&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">xrs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">reshape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">xrs&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 3. Inspect attributes&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">numel&lt;/span>&lt;span class="p">())&lt;/span> &lt;span class="c1"># total number of elements&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">X&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># sum of all elements; returns a single-element tensor&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 4. Indexing&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">X&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">:&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="c1"># indices start at 0 and slices are half-open, so 1:3 selects indices 1 and 2 (the second and third rows)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">X&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="c1"># the second column&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">X&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">:&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">:&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="c1"># second and third rows, second and third columns&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">X&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="n">Y&lt;/span> &lt;span class="c1"># element-wise comparison; returns a boolean tensor&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">X&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="c1"># boolean-mask filtering: keeps only the elements greater than 0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 5. Concatenating tensors&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cat&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Y&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">dim&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># concatenate along dim 0 (stack rows); X and Y need the same number of columns&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cat&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Y&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">dim&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># concatenate along dim 1 (append columns); X and Y need the same number of rows&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 6. Broadcasting&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># `a` has shape (3, 1) and `b` has shape (2,), which broadcasting treats as (1, 2), so the shapes do not match directly.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Both are *broadcast* to a common (3, 2) shape: `a` is repeated along the columns, `b` along the rows, and the sum is then taken element-wise.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">a&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([[&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">]])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">b&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">a&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="n">b&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># result: tensor([[5, 6], [6, 7], [7, 8]])&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
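&lt;p>The broadcasting rule from step 6 can be sketched in pure Python. This is only an illustration, not how PyTorch implements it: &lt;code>broadcast_shape&lt;/code> is a hypothetical helper, and it assumes the usual rule that shapes are aligned from the last dimension and that a missing or size-1 dimension stretches to match the other shape.&lt;/p>

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape, or raise ValueError if incompatible.

    Hypothetical helper for illustration; it ignores edge cases such as
    zero-sized dimensions.
    """
    result = []
    # Walk both shapes from the trailing dimension; a missing dimension counts as 1.
    for dim_a, dim_b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            raise ValueError(f"cannot broadcast {shape_a} with {shape_b}")
    return tuple(reversed(result))


# Shapes of `a` and `b` from the example: (3, 1) and (2,)
print(broadcast_shape((3, 1), (2,)))  # (3, 2)
```

&lt;p>Two shapes such as (3,) and (2,) are rejected, because neither dimension is 1 and they differ.&lt;/p>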
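&lt;/div>&lt;h4 id="张量操作-1">Tensor operations
&lt;/h4>&lt;pre>&lt;code>torch.matmul(x, y) matrix multiplication
torch.dot(x, y) vector dot product (1-D tensors only)
torch.sum(x) sum of all elements
torch.mean(x) mean of all elements
torch.max(x) maximum over all elements
torch.min(x) minimum over all elements
torch.argmax(x, dim) index of the maximum along the given dimension
torch.softmax(x, dim) softmax along the given dimension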
&lt;/code>&lt;/pre>
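&lt;p>A short sketch of how a few of these calls behave on a small input. The tensors &lt;code>x&lt;/code> and &lt;code>w&lt;/code> are made-up values chosen so the results are easy to check by hand.&lt;/p>

```python
import torch

# Made-up example data: two rows of three scores each, and a weight vector.
x = torch.tensor([[1.0, 2.0, 3.0],
                  [1.0, 5.0, 1.0]])
w = torch.tensor([1.0, 0.0, -1.0])

scores = torch.matmul(x, w)      # (2, 3) times (3,) gives shape (2,): values -2 and 0
probs = torch.softmax(x, dim=1)  # softmax over each row, so every row sums to 1
best = torch.argmax(x, dim=1)    # index of the largest entry in each row: 2 and 1

print(scores)
print(probs.sum(dim=1))
print(best)
print(torch.max(x), torch.mean(x))
```

&lt;p>With &lt;code>dim=1&lt;/code> the reduction runs across the columns of each row; &lt;code>dim=0&lt;/code> would instead work down each column.&lt;/p>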
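&lt;h4 id="形状操作">Shape operations
&lt;/h4>&lt;pre>&lt;code>x.view(shape) change the shape without copying the data (needs a compatible memory layout)
x.reshape(shape) like view, but more flexible: copies the data when a view is not possible
x.t() transpose a matrix
x.unsqueeze(dim) insert a size-1 dimension at position dim, e.g. unsqueeze(1) turns shape [N] into [N, 1]
x.squeeze(dim) remove dimension dim if its size is 1
torch.cat((x, y), dim) concatenate tensors along the given dimension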
&lt;/code>&lt;/pre>
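&lt;p>The shape operations above, applied to a small tensor. The variable names are illustrative only. The last lines show one case where &lt;code>reshape&lt;/code> is the safer choice: a transposed tensor is not contiguous in memory, so &lt;code>view&lt;/code> would raise an error there.&lt;/p>

```python
import torch

x = torch.arange(6)                 # shape (6,)
m = x.reshape(2, 3)                 # shape (2, 3)
col = x.unsqueeze(1)                # shape (6, 1): new axis at position 1
flat = col.squeeze(1)               # back to shape (6,)
stacked = torch.cat((m, m), dim=0)  # shape (4, 3): rows of the second tensor appended

# A transposed tensor shares storage with m but is non-contiguous,
# so reshape() (which may copy) is used here instead of view().
mt = m.t()                          # shape (3, 2)
flat_t = mt.reshape(6)              # elements in row-major order of mt

print(m.shape, col.shape, flat.shape, stacked.shape)
print(mt.is_contiguous())
print(flat_t)
```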
&lt;h3 id="线性代数">Linear algebra
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 1. Matrix multiplication&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([[&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">]])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">B&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([[&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">6&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">8&lt;/span>&lt;span class="p">]])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">C&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">mm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">A&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">B&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># matrix product of two 2-D tensors&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 2. Element-wise multiplication&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">a&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">b&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">6&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">c&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">a&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">b&lt;/span> &lt;span class="c1"># element-wise (Hadamard) product&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 3. Vector dot product (element-wise product followed by a sum)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">a&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">b&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">6&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">c&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dot&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">a&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">b&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># dot product, i.e. torch.sum(a * b)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 4. Matrix inverse&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([[&lt;/span>&lt;span class="mf">1.0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">2.0&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="mf">3.0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">4.0&lt;/span>&lt;span class="p">]])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A_inv&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">inverse&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">A&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># matrix inverse; needs a floating-point square matrix&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 5. Matrix transpose&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">AT&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">A&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">T&lt;/span> &lt;span class="c1"># transpose; .T is an attribute, not a method (A.t() is the method form)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 6. Reduction (sum or mean along a given axis removes that axis)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">arange&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">20&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">float32&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">reshape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">mean&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">axis&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">A&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">axis&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">A&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 7. Summing without reducing the number of dimensions&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">sum_A&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">A&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">axis&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">keepdims&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># keepdims=True leaves the summed axis in place with size 1, so broadcasting applies&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span>&lt;span class="o">/&lt;/span>&lt;span class="n">sum_A&lt;/span> &lt;span class="c1"># each row divided by its own sum, so every row sums to 1 (one probability distribution per row)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cumsum&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">axis&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># cumulative sum along an axis; the shape is unchanged&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 8. Norms&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mf">1.0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">2.0&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">A&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">p&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># L2 norm of a vector: square root of the sum of squared elements&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">A&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">p&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># L1 norm of a vector: sum of the absolute values of its elements&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">abs&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">A&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># the same L1 norm, written out explicitly&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Frobenius norm of a matrix (square root of the sum of squared elements, analogous to the vector L2 norm)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">arange&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">36&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">float32&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">reshape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">9&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
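&lt;p>The inverse, norm and row-normalisation steps can be checked numerically. This is a minimal sketch with made-up values (&lt;code>v&lt;/code> and &lt;code>P&lt;/code> are not from the notes); it uses float tensors throughout, since &lt;code>torch.inverse&lt;/code> and &lt;code>torch.norm&lt;/code> expect floating-point input.&lt;/p>

```python
import torch

# A matrix times its inverse should give the identity (up to rounding).
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
A_inv = torch.inverse(A)
identity = torch.mm(A, A_inv)
print(torch.allclose(identity, torch.eye(2), atol=1e-5))

# Vector norms on a 3-4-5 example.
v = torch.tensor([3.0, -4.0])
l2 = torch.norm(v, p=2)  # sqrt(3**2 + 4**2) = 5
l1 = torch.norm(v, p=1)  # |3| + |-4| = 7
print(l2, l1)

# Row normalisation: keepdim=True keeps shape (2, 1), so the division broadcasts per row.
P = torch.arange(1.0, 7.0).reshape(2, 3)
row_prob = P / P.sum(dim=1, keepdim=True)
print(row_prob.sum(dim=1))  # each row now sums to 1
```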
&lt;/div>&lt;h3 id="导数和梯度">Derivatives and gradients
&lt;/h3>&lt;h4 id="画图">Plotting
&lt;/h4>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># See chapter0 calculus.py&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">numpy&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">np&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">matplotlib.pyplot&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">plt&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Values for the x and y axes&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">linspace&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">15&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">100&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">y1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sin&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">y2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cos&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Plot on the current axes (x values, y values, legend label of the curve, line colour, line width)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">plt&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">plot&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">y1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">label&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;$sin(x)$&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">color&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;blue&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">linewidth&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">plt&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">plot&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">y2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">label&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;$cos(x)$&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">color&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;red&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">linewidth&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Labels for the x and y axes&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">plt&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">xlabel&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;Domain&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">plt&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ylabel&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;Range&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Title of the chart&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">plt&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">title&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;sin and cos&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Range of the y axis&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">plt&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ylim&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">-&lt;/span>&lt;span class="mf">1.5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">1.5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Show the legend&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">plt&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">legend&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Display the figure&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">plt&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">show&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
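&lt;p>The two curves plotted above are related: cos(x) is the derivative of sin(x). A minimal sketch of how that can be checked numerically with a central difference. &lt;code>numerical_derivative&lt;/code> and its step size &lt;code>h&lt;/code> are hypothetical choices for illustration, not part of the original notes.&lt;/p>

```python
import math


def numerical_derivative(f, x, h=1e-5):
    """Central-difference estimate of f'(x); h is an assumed step size."""
    return (f(x + h) - f(x - h)) / (2 * h)


# Compare the estimate for sin with the exact derivative cos at a few points.
for x in (0.0, 1.0, 2.0):
    approx = numerical_derivative(math.sin, x)
    print(f"x={x:.1f}  estimate={approx:.6f}  cos(x)={math.cos(x):.6f}")
```

&lt;p>The estimate and the exact value agree to many decimal places here; a much larger or much smaller &lt;code>h&lt;/code> would make the estimate less accurate.&lt;/p>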
&lt;/div>&lt;h4 id="自动求导">Automatic differentiation
&lt;/h4>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">arange&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mf">4.0&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># the dtype must be floating point for the tensor to be differentiable&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">requires_grad_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># equivalent to x = torch.arange(4.0, requires_grad=True)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">y&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">2&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dot&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># y = 2 * sum(x_i ** 2)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">y&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># backpropagate; the expected gradient is dy/dx_i = 4 * x_i&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># print the gradient tensor&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">##1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># PyTorch accumulates gradients by default, so the previous values must be cleared first&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zero_&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">y&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># NOTE: y is a scalar here. It plays the role of a loss value: backward() without arguments can only start from a scalar&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">y&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># print the gradient tensor&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">##2 Here we only want the sum of the partial derivatives, so passing a gradient of ones is appropriate&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zero_&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">y&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">x&lt;/span> &lt;span class="c1"># Hadamard (element-wise) product, so y is a vector&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># the next line is equivalent to y.backward(torch.ones(len(x)))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">y&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">##3 Detaching part of the computation graph&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x0&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">arange&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mf">4.0&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">requires_grad&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># As in ##2, only the sum of the partial derivatives is needed, so a gradient of ones is appropriate&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">y&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x0&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">x0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Goal: use only the current value of y as a constant u and compute z = u * x0, so that no gradient flows back through y (otherwise z would be differentiated as x0 * x0 * x0)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">u&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">y&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">detach&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># holds the current value of y as a constant; no gradient flows through u&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">u&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">z&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">u&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">x0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">z&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x0&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x0&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="n">u&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># u is treated as a constant, so dz/dx0 = u&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x0&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zero_&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">y&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x0&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x0&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">2&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">x0&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># y = x0 * x0, so dy/dx0 = 2 * x0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">##4 Gradients through Python control flow&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># This example shows that autograd still tracks operations that run inside Python control flow (loops, conditional branches)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">f&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">a&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">b&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">a&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">while&lt;/span> &lt;span class="n">b&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="mi">1000&lt;/span>&lt;span class="p">:&lt;/span>  &lt;span class="c1"># checks how a loop affects the gradient&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">b&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">b&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">2&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="n">b&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>  &lt;span class="c1"># checks how a conditional branch affects the gradient&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">c&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">b&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">else&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">c&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">100&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">b&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">c&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">a&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">requires_grad&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">d&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">f&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">a&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># d = 2^n * k * a, where n is the number of loop iterations and k is 1 or 100 depending on the branch taken&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">d&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># f is piecewise linear in a, so the derivative of d with respect to a is d / a&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">a&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">a&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">grad&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="n">d&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">a&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="深度学习计算">Deep learning computation
&lt;/h3>&lt;h4 id="gpu相关">GPU usage
&lt;/h4>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Restrict which GPUs are visible; set this before the first CUDA call&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">os&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">os&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">environ&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;CUDA_VISIBLE_DEVICES&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="s2">&amp;#34;1,3&amp;#34;&lt;/span>  &lt;span class="c1"># only the 2nd and 4th GPUs are visible; the value must be a string&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">## Alternatively, run export CUDA_VISIBLE_DEVICES=1,3 in the shell before launching. With this setting, cuda:1 inside the program is physical GPU 3&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Check whether a GPU is available&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">is_available&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Select the device&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">device&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;cpu&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">device&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;cuda:0&amp;#34;&lt;/span> &lt;span class="k">if&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">is_available&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="k">else&lt;/span> &lt;span class="s2">&amp;#34;cpu&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Create a tensor directly on the selected device&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([[&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">],[&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">6&lt;/span>&lt;span class="p">]],&lt;/span> &lt;span class="n">device&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Move the tensor to the selected device (the GPU when one is available)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Move the tensor back to the CPU&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;cpu&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h4 id="自定义块">Custom blocks
&lt;/h4>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.nn&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">functional&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="n">F&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># A custom block declares its layers (and their parameters) in __init__ and uses them in forward; layers can be nested either inside forward or with nn.Sequential&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MLP&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># Declare the layers that hold the model parameters; here, two fully connected layers&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># run nn.Module's constructor for the required initialization; extra constructor arguments (e.g. params) could be accepted here as well&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">hidden&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">20&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">256&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># hidden layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">out&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">256&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># output layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># Forward pass: how to compute the model output from the input X&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">X&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Note that this uses the functional form of ReLU, defined in the nn.functional module&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">out&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">hidden&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="p">)))&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h4 id="自定义顺序块">Custom sequential block
&lt;/h4>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Mirrors the behavior of nn.Sequential&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MySequential&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">*&lt;/span>&lt;span class="n">args&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">for&lt;/span> &lt;span class="n">idx&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">module&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">args&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="c1"># Here, module is an instance of a Module subclass. We store it in the member&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="c1"># variable _modules of the Module class. _modules is an ordered dictionary&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">_modules&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="nb">str&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">idx&lt;/span>&lt;span class="p">)]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">module&lt;/span>  &lt;span class="c1"># register each submodule under its index as the key&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">X&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># The ordered dictionary guarantees iteration in insertion order&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">for&lt;/span> &lt;span class="n">block&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">_modules&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">values&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">X&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">block&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># feed each block's output into the next block&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">X&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MySequential&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">20&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">256&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">256&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Chains two fully connected layers with a ReLU layer in between&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h4 id="参数管理">Parameter management
&lt;/h4>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Parameter access&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">8&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">X&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">())&lt;/span>  &lt;span class="c1"># parameters of the nn.Linear(8, 1) layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">*&lt;/span>&lt;span class="p">[(&lt;/span>&lt;span class="n">name&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">param&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="n">name&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">param&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">net&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">named_parameters&lt;/span>&lt;span class="p">()])&lt;/span>  &lt;span class="c1"># names and shapes of all parameters&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h4 id="参数初始化">Parameter initialization
&lt;/h4>&lt;h5 id="xavier初始化原理">How Xavier initialization works
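&lt;/h5>&lt;p>Goal: keep the variance of each layer's outputs (forward pass) and of its gradients (backward pass) roughly constant from layer to layer, so that signals neither blow up nor vanish as the network gets deeper.&lt;/p>
&lt;p>Consider a fully connected layer without a bias term. Its outputs are $o_i$, it has $N_{in}$ inputs $x_j$ and $N_{out}$ outputs, and its weights are $w_{ij}$.&lt;/p>
&lt;p>The weights $w_{ij}$ are drawn independently from the same distribution with zero mean and variance $σ^2$. The distribution does not have to be Gaussian.
Now assume that the layer inputs $x_j$ also have zero mean and variance $γ^2$, and that they are independent of the $w_{ij}$ and of each other.&lt;/p>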
&lt;p>The output can be written as
$$o_i = \sum_{j=1}^{N_{in}} w_{ij} x_j$$
Its mean is
$$E[o_i] = \sum_{j=1}^{N_{in}} E[w_{ij} x_j] = \sum_{j=1}^{N_{in}} E[w_{ij}] E[x_j] = 0$$&lt;/p>
&lt;p>and its variance is
$$Var[o_i] = \sum_{j=1}^{N_{in}} Var[w_{ij} x_j] = \sum_{j=1}^{N_{in}} E[w_{ij}^2 x_j^2] - 0
= \sum_{j=1}^{N_{in}} E[w_{ij}^2]E[x_j^2]=N_{in}σ^2γ^2$$&lt;/p>
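&lt;p>The variance formula above can be checked numerically with a small Monte Carlo sketch in plain Python. The values of n_in, sigma and gamma are arbitrary illustrative choices, not values from the text, and Gaussian draws are used only for convenience, since the derivation needs nothing beyond zero mean and independence.&lt;/p>
&lt;pre>&lt;code class="language-python" data-lang="python">import random
import statistics

random.seed(0)

n_in = 50       # number of inputs N_in (illustrative value)
sigma = 0.2     # std of the weights, so Var[w] = sigma**2
gamma = 1.5     # std of the inputs, so Var[x] = gamma**2
trials = 10000  # number of independent draws of o_i


def sample_output():
    # o_i = sum_j w_ij * x_j with fresh zero-mean weights and inputs
    return sum(random.gauss(0.0, sigma) * random.gauss(0.0, gamma)
               for _ in range(n_in))


outputs = [sample_output() for _ in range(trials)]
empirical_var = statistics.pvariance(outputs)
predicted_var = n_in * sigma**2 * gamma**2  # N_in * sigma^2 * gamma^2

print(empirical_var, predicted_var)
&lt;/code>&lt;/pre>
&lt;p>With these settings the predicted variance is 50 * 0.04 * 2.25 = 4.5, and the empirical estimate should land close to that value.&lt;/p>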
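&lt;p>One way to keep the variance unchanged in the forward pass is to set $N_{in}σ^2 = 1$;&lt;br>
by the same argument, backpropagation requires $N_{out}σ^2 = 1$, otherwise the variance of the gradients can grow.&lt;/p>
&lt;p>&lt;strong>Both conditions cannot hold at once unless $N_{in} = N_{out}$, so Xavier initialization settles for the compromise $\frac{1}{2}(N_{in}+N_{out})σ^2 = 1$, i.e. $σ^2 = \frac{2}{N_{in}+N_{out}}$.&lt;/strong>&lt;/p>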
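&lt;p>Once the variance is fixed, $w_{ij}$ can be sampled from either a Gaussian or a uniform distribution.
Gaussian sampling (the second argument is the standard deviation):
$$w_{ij} \sim \mathcal{N}(0, \sqrt{\frac{2}{N_{in}+N_{out}}})$$
Uniform sampling (the bound follows from $Var[\mathcal{U}(-a, a)] = a^2/3$):
$$w_{ij} \sim \mathcal{U}(-\sqrt{\frac{6}{N_{in}+N_{out}}}, \sqrt{\frac{6}{N_{in}+N_{out}}})$$&lt;/p>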
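&lt;p>As a sanity check on the two sampling rules, the following plain-Python sketch draws weights from both distributions and compares their empirical variance with the target $2/(N_{in}+N_{out})$. The layer sizes and sample count are illustrative assumptions, and the standard library stands in for the framework's own initializers.&lt;/p>
&lt;pre>&lt;code class="language-python" data-lang="python">import math
import random
import statistics

random.seed(0)

n_in, n_out = 20, 256               # layer sizes (illustrative values)
target_var = 2.0 / (n_in + n_out)   # sigma^2 from the Xavier condition

std = math.sqrt(target_var)               # Gaussian: standard deviation
bound = math.sqrt(6.0 / (n_in + n_out))   # uniform: half-width a, since Var[U(-a, a)] = a^2 / 3

n_samples = 100000
normal_w = [random.gauss(0.0, std) for _ in range(n_samples)]
uniform_w = [random.uniform(-bound, bound) for _ in range(n_samples)]

normal_var = statistics.pvariance(normal_w)
uniform_var = statistics.pvariance(uniform_w)

# Both empirical variances should be close to target_var
print(target_var, normal_var, uniform_var)
&lt;/code>&lt;/pre>
&lt;p>Both samplers target the same variance; they differ only in the shape of the distribution and in the fact that the uniform version bounds every weight by the half-width.&lt;/p>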
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">init&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">xavier_uniform_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># the function derives the bound from the tensor's shape; pass the weight of a single layer (here the first nn.Linear of net), since nn.Sequential itself has no weight attribute&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h5 id="关于netapply">About net.apply
&lt;/h5>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">init_normal&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="nb">type&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">init&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">normal_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">mean&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">std&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.01&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">init&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zeros_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bias&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;blockquote>
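&lt;p>For net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1)), what is the difference between net.apply(init_normal) and init_normal(net)?&lt;/p>
&lt;/blockquote>
&lt;p>net.apply(init_normal) recursively visits every submodule and then net itself, calling init_normal on each one, so both nn.Linear layers are re-initialized;&lt;br>
init_normal(net) passes the whole model as m. Because type(net) is nn.Sequential rather than nn.Linear, the if condition is false, so the call returns without raising an error and without initializing anything.&lt;/p>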
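&lt;p>The difference can be checked directly. The sketch below reuses init_normal from above; make_net and the bias comparison are hypothetical helpers introduced only for this illustration. The direct call leaves the default initialization untouched because the type check fails for nn.Sequential, while apply reaches both nn.Linear layers.&lt;/p>
&lt;pre>&lt;code class="language-python" data-lang="python">import torch
from torch import nn


def init_normal(m):
    # Only acts on modules whose exact type is nn.Linear
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, mean=0, std=0.01)
        nn.init.zeros_(m.bias)


def make_net():
    # Hypothetical helper: builds the network from the question
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))


# Direct call: type(net) is nn.Sequential, so the if branch is skipped
net_direct = make_net()
bias_before = net_direct[0].bias.detach().clone()
init_normal(net_direct)
direct_changed = not torch.equal(net_direct[0].bias.detach(), bias_before)

# apply: init_normal runs on every submodule, then on the net itself
net_applied = make_net()
net_applied.apply(init_normal)
applied_bias_is_zero = bool((net_applied[0].bias == 0).all())

print(direct_changed, applied_bias_is_zero)  # prints: False True
&lt;/code>&lt;/pre>
&lt;p>Since the direct call fails silently, it is easy to miss; checking a parameter after initialization, as done here, is one simple safeguard.&lt;/p>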
&lt;h4 id="自定义层">Custom layers
&lt;/h4>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Custom layer without parameters&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Subtracts the mean from its input; it has no parameters and adapts to any input shape&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">CenteredLayer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">X&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">X&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="n">X&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">mean&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Custom layer with parameters&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># A linear layer with in_units inputs and out_units outputs, followed by a ReLU&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MyLinear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">in_units&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">out_units&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Parameter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">in_units&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">out_units&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bias&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Parameter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">out_units&lt;/span>&lt;span class="p">,))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">X&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">linear&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">matmul&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bias&lt;/span>  &lt;span class="c1"># use the Parameters directly (not .data) so autograd tracks them&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">linear&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h4 id="保存和加载模型">Saving and loading a model
&lt;/h4>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Save the model parameters (torch.save also accepts tensors, lists, dicts, etc.)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MLP&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="s1">&amp;#39;net.params&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Load the parameters into a new instance of the same architecture&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">clone_net&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MLP&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">clone_net&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load_state_dict&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;net.params&amp;#39;&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
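&lt;/div>&lt;h2 id="数据加载与处理">Data loading and processing
&lt;/h2>&lt;h3 id="dataset--dataloader">Dataset &amp;amp; DataLoader
&lt;/h3>&lt;p>torch.utils.data.Dataset: the abstract dataset class. Subclass it and implement __len__ (the size of the dataset) and __getitem__ (fetch one sample by index).&lt;br>
torch.utils.data.DataLoader: an iterable that wraps a Dataset and provides batching (batch_size), shuffling (shuffle=True), and multi-process loading (num_workers), which makes it convenient to feed data to a model during training.&lt;/p>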
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.utils.data&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">Dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.utils.data&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">DataLoader&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Custom dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MyDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Dataset&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">data&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Store the data and labels&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">data&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">labels&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__len__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Return the size of the dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__getitem__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">idx&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Return the sample and label at the given index&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">sample&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">idx&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">label&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">idx&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">label&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Generate example data&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">data&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">100&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 100 samples with 5 features each&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randint&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="mi">100&lt;/span>&lt;span class="p">,))&lt;/span> &lt;span class="c1"># 100 labels, each either 0 or 1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Instantiate the dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dataset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MyDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Create a DataLoader; batch_size sets the number of samples per batch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dataloader&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dataset&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_workers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Iterate over the DataLoader&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">batch_idx&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">batch_data&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_labels&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dataloader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s2">&amp;#34;Batch &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">batch_idx&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;Data:&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_data&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;Labels:&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_labels&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="n">batch_idx&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="c1"># show only the first 3 batches&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">break&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
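&lt;/div>&lt;h3 id="数据转换">Data transforms
&lt;/h3>&lt;p>torchvision.transforms provides basic preprocessing (normalization, resizing, and so on) as well as data augmentation (random cropping, flipping, and so on), which helps a model generalize better.&lt;/p>
&lt;p>Basic transforms:&lt;br>
transforms.ToTensor(): converts a PIL image or NumPy array (H x W x C) to a PyTorch tensor (C x H x W); for 8-bit images it also rescales pixel values from [0, 255] to [0, 1]. Example: transform = transforms.ToTensor()&lt;br>
transforms.Normalize(mean, std): standardizes each channel as (x - mean) / std; the output has zero mean and unit variance only when mean and std match the statistics of the data. Example: transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])&lt;br>
transforms.Resize(size): resizes the image so that every input to the network has the same size. Example: transform = transforms.Resize((256, 256))&lt;br>
transforms.CenterCrop(size): crops a region of the given size from the center of the image. Example: transform = transforms.CenterCrop(224)&lt;/p>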
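&lt;p>As a rough illustration of the arithmetic behind ToTensor() and Normalize(), the sketch below reproduces the two steps for a single pixel value in plain Python. The helper names to_unit_range and normalize_value are made up for this example and are not part of torchvision; the mean and std are the first-channel values listed above.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">def to_unit_range(pixel):
    # Hypothetical helper: mimic the rescaling of an 8-bit pixel from [0, 255] to [0, 1]
    return pixel / 255.0


def normalize_value(x, mean, std):
    # Hypothetical helper: the per-channel formula applied by Normalize, (x - mean) / std
    return (x - mean) / std


scaled = to_unit_range(128)  # roughly 0.502
standardized = normalize_value(scaled, mean=0.485, std=0.229)  # roughly 0.074
print(round(scaled, 3), round(standardized, 3))
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>The real transforms apply the same formula to every pixel of every channel at once; the scalar version only makes the numbers easy to check by hand.&lt;/p>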
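&lt;p>Data augmentation transforms:&lt;br>
transforms.RandomHorizontalFlip(p): flips the image horizontally with probability p. Example: transform = transforms.RandomHorizontalFlip(p=0.5)&lt;br>
transforms.RandomRotation(degrees): rotates the image by a random angle; degrees=45 samples the angle from [-45, 45]. Example: transform = transforms.RandomRotation(degrees=45)&lt;br>
transforms.ColorJitter(brightness, contrast, saturation, hue): randomly changes the brightness, contrast, saturation, and hue of the image. Example: transform = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)&lt;br>
transforms.RandomCrop(size): crops a region of the given size at a random position. Example: transform = transforms.RandomCrop(224)&lt;br>
transforms.RandomResizedCrop(size): crops a random region of the image and resizes it to the given size. Example: transform = transforms.RandomResizedCrop(224)&lt;/p>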
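&lt;p>To make the role of the probability p concrete, here is a minimal plain-Python sketch of a random horizontal flip on an image stored as a nested list. The function name random_hflip and its rng argument are assumptions made for this illustration; torchvision's own implementation works on PIL images and tensors.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import random


def random_hflip(image, p=0.5, rng=random):
    # Hypothetical helper: with probability p, reverse every row (mirror left to right);
    # otherwise keep the rows as they are. A new nested list is returned in both cases.
    if p > rng.random():
        return [row[::-1] for row in image]
    return [list(row) for row in image]


image = [[1, 2, 3],
         [4, 5, 6]]
print(random_hflip(image, p=1.0))  # always flipped: [[3, 2, 1], [6, 5, 4]]
print(random_hflip(image, p=0.0))  # never flipped: [[1, 2, 3], [4, 5, 6]]
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>Because the decision is drawn again on every call, the same training image can appear flipped in one epoch and unflipped in the next, which is what makes this an augmentation rather than a fixed preprocessing step.&lt;/p>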
&lt;p>Custom transforms:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">CustomTransform&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__call__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Any custom transform logic can go here&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">2&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">transform&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">CustomTransform&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Composing transforms:&lt;br>
transforms.Compose() chains several transforms and applies them in the order given.&lt;br>
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.Resize((256, 256))])&lt;/p>
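&lt;p>The effect of ordering in Compose can be sketched without torchvision. SimpleCompose below is a hypothetical stand-in that mirrors the documented behavior (call each transform in turn), and AddOne is likewise invented for the example; CustomTransform repeats the class from the custom transform snippet.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">class CustomTransform:
    def __call__(self, x):
        return x * 2


class AddOne:
    # Hypothetical second transform, used only to show that order matters
    def __call__(self, x):
        return x + 1


class SimpleCompose:
    # Hypothetical stand-in for transforms.Compose: apply the transforms in list order
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


double_then_add = SimpleCompose([CustomTransform(), AddOne()])
add_then_double = SimpleCompose([AddOne(), CustomTransform()])
print(double_then_add(3))  # 3 * 2 + 1 = 7
print(add_then_double(3))  # (3 + 1) * 2 = 8
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>The same rule applies to real image pipelines: Normalize expects a tensor, so it has to come after ToTensor().&lt;/p>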
&lt;h2 id="线性神经网络">Linear neural networks
&lt;/h2>&lt;h3 id="线性回归的简洁实现">Concise implementation of linear regression
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">numpy&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">np&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.utils&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">data&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">d2l&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">torch&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="n">d2l&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">true_w&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">-&lt;/span>&lt;span class="mf">3.4&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">true_b&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mf">4.2&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">features&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">d2l&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">synthetic_data&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">true_w&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">true_b&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1000&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">load_array&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">data_arrays&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">is_train&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">):&lt;/span> &lt;span class="c1">#@save&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;Construct a PyTorch data iterator.&amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">dataset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">TensorDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">*&lt;/span>&lt;span class="n">data_arrays&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dataset&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">is_train&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">batch_size&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">10&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">data_iter&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">load_array&lt;/span>&lt;span class="p">((&lt;/span>&lt;span class="n">features&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">next&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">iter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">data_iter&lt;/span>&lt;span class="p">))&lt;/span>  &lt;span class="c1"># fetch the first batch of features and labels as a sanity check&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># nn is short for neural networks&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">normal_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.01&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># in torch a trailing underscore marks an in-place operation: normal_ fills the weight values with samples from a normal distribution with mean 0 and standard deviation 0.01&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bias&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fill_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">MSELoss&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># loss function: mean squared error (the squared L2 norm, averaged over elements)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">SGD&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.03&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># optimization algorithm (SGD) and learning rate&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">num_epochs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">3&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">epoch&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_epochs&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">for&lt;/span> &lt;span class="n">X&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">y&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">data_iter&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">l&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">loss&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">y&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># forward pass and loss computation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zero_grad&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># clear the accumulated gradients&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">l&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># backpropagation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">step&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># update the parameters&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">l&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">loss&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">features&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s1">&amp;#39;epoch &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">epoch&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">, loss &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">l&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s1">f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">w&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">net&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;Estimation error of w:&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">true_w&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="n">w&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">reshape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">true_w&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">b&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">net&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bias&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;Estimation error of b:&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">true_b&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="n">b&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
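&lt;/div>&lt;h3 id="softmax回归">Softmax regression
&lt;/h3>&lt;p>The softmax function turns a vector of scores into a probability distribution: every output lies between 0 and 1, and the outputs sum to 1. This is why softmax regression is commonly used for multi-class classification.&lt;/p>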
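&lt;p>For reference, softmax maps scores o to exp(o_i) / sum_j exp(o_j). The following is a minimal plain-Python sketch, assuming a flat list of scores; the function name softmax and the example scores are chosen for illustration, and in PyTorch one would normally call torch.softmax or rely on nn.CrossEntropyLoss instead.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import math


def softmax(scores):
    # Subtract the maximum first; this leaves the result unchanged but avoids overflow in exp
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([1.0, 2.0, 3.0])
print([round(p, 3) for p in probs])  # [0.09, 0.245, 0.665]
print(sum(probs))  # 1.0 up to floating-point rounding
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>The largest score receives the largest probability, so picking the most probable class is the same as picking the class with the highest score.&lt;/p>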
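&lt;h4 id="图像分类数据集">Image classification dataset
&lt;/h4>&lt;p>We use the Fashion-MNIST dataset, which contains 70,000 grayscale images (60,000 for training and 10,000 for testing) in 10 classes; each image is 28 pixels high and 28 pixels wide.&lt;/p>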
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Concise implementation of softmax regression (the from-scratch version is in the project code)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">d2l&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">torch&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="n">d2l&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">batch_size&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">256&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">train_iter&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">test_iter&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">d2l&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load_data_fashion_mnist&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">batch_size&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># PyTorch does not reshape inputs implicitly, so a flatten layer is placed before the linear layer to reshape the network input&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Flatten&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">784&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># nn.Flatten() keeps the batch dimension and flattens the remaining ones, so each 1x28x28 image becomes a vector of length 784&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">init_weights&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="p">):&lt;/span>  &lt;span class="c1"># initialize the weights&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="nb">type&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">init&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">normal_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">std&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.01&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">apply&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">init_weights&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CrossEntropyLoss&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">reduction&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;none&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">trainer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">SGD&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">num_epochs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">10&lt;/span>
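&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># train_ch6(net, train_iter, test_iter, num_epochs, lr, device) appears to build its own SGD optimizer and cross-entropy loss, so loss and trainer above are only needed by train_ch3&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">d2l&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train_ch6&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">train_iter&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">test_iter&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.03&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d2l&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">try_gpu&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># The installed d2l package no longer includes the chapter 3 trainer: d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)&lt;/span>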
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="多层感知机mlp">Multilayer perceptron (MLP)
&lt;/h2>&lt;h3 id="多层感知机的简洁实现">Concise implementation of the multilayer perceptron
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Network definition only&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Flatten&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">784&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">256&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">256&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">init_weights&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="nb">type&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">init&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">normal_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">std&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.01&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">apply&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">init_weights&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
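&lt;/div>&lt;h3 id="权重衰减">Weight decay
&lt;/h3>&lt;p>Weight decay is also known as L2 regularization. It keeps the model parameters from growing too large and thereby limits model complexity. The regularization weight λ is a hyperparameter that controls this trade-off.&lt;/p>
&lt;p>Objective = loss + regularization term:&lt;br>
$$L(w, b) + \frac{\lambda}{2}\|w\|^2$$
One reason to use the L2 norm is that it places a large penalty on large components of the weight vector. This biases the learning algorithm toward models that spread weight evenly across many features, which in practice can make them more robust to measurement error in any single variable.&lt;/p>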
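&lt;p>The preference for evenly spread weights can be checked with a small calculation. The sketch below compares the penalty of two weight vectors whose components add up to the same total; the helper name and the numbers are made up for illustration.&lt;/p>

```python
def l2_penalty(weights, lam):
    """Return (lam / 2) * ||w||^2 for a list of weights."""
    return lam / 2 * sum(w * w for w in weights)

lam = 1.0                  # illustrative value of the regularization weight
concentrated = [2.0, 0.0]  # all weight on one feature
spread = [1.0, 1.0]        # same total weight, spread over two features

penalty_concentrated = l2_penalty(concentrated, lam)
penalty_spread = l2_penalty(spread, lam)
print(penalty_concentrated, penalty_spread)
```

&lt;p>The concentrated vector is penalized twice as much (2.0 against 1.0), because squaring makes one large component cost more than several small ones.&lt;/p>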
&lt;p>&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep_learning/pytorch/wdecay.png"
loading="lazy"
alt="wdecay.png"
>
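The axes correspond to the components of w.
The green point marks the minimizer of L(w, b), and the origin marks the minimum of the L2 norm. The loss grows with distance from the green point, and the penalty grows with distance from the origin.
The yellow point is the minimum of the penalized objective, where the two terms balance; this is the effect of weight decay.&lt;/p>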
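&lt;p>Weight update:
$$w_{t+1} \leftarrow (1 - \eta\lambda) w_t - \frac{\eta}{B}\sum \frac{\partial L(w, b)}{\partial w}$$
Here $\eta$ is the learning rate and $B$ is the minibatch size. Note the factor $(1 - \eta \lambda)$: the current weights $w_t$ are first shrunk (decayed), and then the gradient step is applied.&lt;/p>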
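&lt;p>The update rule can be read in two equivalent ways: shrink the weights and then take a gradient step, or take a plain gradient step on the penalized objective, whose extra gradient term is λw. The following sketch checks that both give the same result; the function names and numbers are assumptions chosen for illustration.&lt;/p>

```python
def decayed_step(w, grad, lr, lam):
    """Shrink first, then update: (1 - lr * lam) * w - lr * grad."""
    return [(1 - lr * lam) * wi - lr * gi for wi, gi in zip(w, grad)]

def penalized_step(w, grad, lr, lam):
    """Plain SGD on loss + (lam / 2) * ||w||^2, whose extra gradient is lam * w."""
    return [wi - lr * (gi + lam * wi) for wi, gi in zip(w, grad)]

w = [1.0, -2.0]
grad = [0.5, 0.5]    # stand-in for the minibatch-averaged loss gradient
lr, lam = 0.1, 0.01  # illustrative learning rate and regularization weight

step_a = decayed_step(w, grad, lr, lam)
step_b = penalized_step(w, grad, lr, lam)
print(step_a, step_b)
```

&lt;p>In PyTorch the same effect is usually obtained by passing weight_decay to the optimizer, for example torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=0.01), rather than by adding the penalty to the loss by hand.&lt;/p>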
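&lt;h3 id="丢弃法dropout">Dropout
&lt;/h3>&lt;p>Each intermediate activation h is replaced by a random variable h′ governed by the dropout probability p: with probability p it is set to zero, and otherwise it is scaled by 1/(1-p), so that its expected value is unchanged.
$$h' = \begin{cases}\frac{h}{1-p}, &amp;amp; \text{with probability } 1-p \\ 0, &amp;amp; \text{with probability } p\end{cases}$$
If the predictions stay consistent across many different dropout masks, the network can be considered more robust.&lt;/p>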
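&lt;p>The expected value is preserved because (1-p) · h/(1-p) + p · 0 = h. A quick simulation illustrates this without any deep learning framework; the helper below is a hypothetical list-based stand-in for a dropout layer, and the seed, sample size, and p are arbitrary.&lt;/p>

```python
import random

def dropout(values, p, rng):
    """Zero each value with probability p and scale the survivors by 1 / (1 - p)."""
    if p >= 1.0:
        return [0.0 for _ in values]
    return [v / (1.0 - p) if rng.random() >= p else 0.0 for v in values]

rng = random.Random(0)  # fixed seed so the run is repeatable
h = [1.0] * 100_000     # many copies of the same activation h = 1
p = 0.3                 # illustrative dropout probability

dropped = dropout(h, p, rng)
mean_after = sum(dropped) / len(dropped)
zero_fraction = dropped.count(0.0) / len(dropped)
print(mean_after, zero_fraction)
```

&lt;p>The mean stays close to the original activation of 1 while roughly a fraction p of the entries are zeroed.&lt;/p>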
&lt;p>Implementation:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Dropout layer: applied after a fully connected layer to randomly drop its outputs&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">dropout_layer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">pdropout&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">assert&lt;/span> &lt;span class="mi">0&lt;/span> &lt;span class="o">&amp;lt;=&lt;/span> &lt;span class="n">pdropout&lt;/span> &lt;span class="o">&amp;lt;=&lt;/span> &lt;span class="mi">1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># p = 1: every element is dropped&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="n">pdropout&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zeros_like&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># p = 0: every element is kept&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="n">pdropout&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">X&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="c1"># Keep each element with probability 1 - pdropout, then rescale the survivors&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">mask&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">X&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="n">pdropout&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">float&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">mask&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">X&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="mf">1.0&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="n">pdropout&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Concise version: use the built-in layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">pdropout&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
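&lt;/div>&lt;h2 id="卷积神经网络cnn">Convolutional neural networks (CNN)
&lt;/h2>&lt;p>CNNs are designed for data with a grid-like topology, such as images.&lt;br>
Typical pipeline: (convolution, pooling) × N, then flatten, then classify.&lt;/p>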
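&lt;p>The flatten step needs to know how large the feature map is after the convolution and pooling stages. The sketch below traces that size for a 28x28 input using the usual output-size formula (dilation 1 assumed); the helper name is hypothetical, and the layer settings mirror the SimpleCNN example in this section.&lt;/p>

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one axis: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                             # Fashion-MNIST images are 28x28
size = conv_out(size, kernel=3, stride=1, padding=1)  # 3x3 conv with padding 1 keeps 28
size = conv_out(size, kernel=2, stride=2)             # 2x2 max pooling halves it to 14
size = conv_out(size, kernel=3, stride=1, padding=1)  # second conv keeps 14
size = conv_out(size, kernel=2, stride=2)             # second pooling gives 7

channels = 64  # output channels of the last convolution in the example
flat_features = channels * size * size
print(size, flat_features)
```

&lt;p>This is where the 64 * 7 * 7 input size of the first fully connected layer comes from; a different image size or padding would change it.&lt;/p>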
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn.functional&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">F&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">SimpleCNN&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">SimpleCNN&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Convolution layer: 1 input channel, 32 output channels, 3x3 kernel&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">padding&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Convolution layer: 32 input channels, 64 output channels&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">padding&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Fully connected layers&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">64&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">7&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">128&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># input size = channels * feature-map height * width (7x7 for a 28x28 input)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">128&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># 10 classes&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span>  &lt;span class="c1"># first convolution + ReLU&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max_pool2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># max pooling&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span>  &lt;span class="c1"># second convolution + ReLU&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max_pool2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># max pooling&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">64&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">7&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">7&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># flatten&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span>  &lt;span class="c1"># fully connected layer + ReLU&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># output of the final fully connected layer (raw class scores)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
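&lt;/div>&lt;h2 id="循环神经网络rnn">Recurrent neural networks (RNN)
&lt;/h2>&lt;p>RNNs are designed for sequential data such as text, time series, or audio, and can capture how information evolves along the sequence.&lt;/p>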
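&lt;p>What lets an RNN track order is its recurrence: each hidden state is computed from the current input and the previous hidden state. The sketch below is a scalar version of the tanh recurrence used by a basic (Elman) RNN, with the two bias terms merged into one; the function name and weights are illustrative, not trained values.&lt;/p>

```python
import math

def rnn_forward(inputs, w_ih, w_hh, bias, h0=0.0):
    """Run a scalar RNN: h_t = tanh(w_ih * x_t + w_hh * h_{t-1} + bias)."""
    h = h0
    states = []
    for x in inputs:
        h = math.tanh(w_ih * x + w_hh * h + bias)
        states.append(h)
    return states

# The same three inputs in a different order lead to different final states.
states_early = rnn_forward([1.0, 0.0, 0.0], w_ih=0.5, w_hh=0.8, bias=0.0)
states_late = rnn_forward([0.0, 0.0, 1.0], w_ih=0.5, w_hh=0.8, bias=0.0)
print(states_early[-1], states_late[-1])
```

&lt;p>An input seen early is carried forward through the hidden state but fades at each step, which is why the final state depends on where in the sequence the input occurred.&lt;/p>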
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">SimpleRNN&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">input_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">hidden_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">output_size&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">SimpleRNN&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># RNN layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rnn&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">RNN&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">input_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">hidden_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_first&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Fully connected layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">hidden_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">output_size&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># x: (batch_size, seq_len, input_size)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">out&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rnn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># out: (batch_size, seq_len, hidden_size)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Use the output at the last time step as the model output&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">out&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">out&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">:]&lt;/span>  &lt;span class="c1"># (batch_size, hidden_size)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">out&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">out&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># fully connected layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">out&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="transformer">Transformer
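&lt;/h2>&lt;p>In practice, the model can be assembled from the built-in nn.Embedding and nn.Transformer layers rather than written by hand. PyTorch does not provide a ready-made positional-encoding layer, so that part is defined explicitly, here as a learnable nn.Parameter.&lt;/p>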
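&lt;p>A common fixed alternative to a learned positional encoding is the sinusoidal scheme from the original Transformer paper, where even columns hold sin(pos / 10000^(2i/d)) and odd columns hold the matching cosine. The sketch below builds that table in plain Python; the function name and the small sizes are assumptions for illustration, and in a real model the table would be converted to a tensor and added to the embeddings.&lt;/p>

```python
import math

def sinusoidal_positional_encoding(seq_len, model_dim):
    """Return a seq_len x model_dim table: sin on even columns, cos on odd columns."""
    table = []
    for pos in range(seq_len):
        row = []
        for col in range(model_dim):
            pair = col // 2  # columns 2i and 2i+1 share one frequency
            angle = pos / (10000 ** (2 * pair / model_dim))
            row.append(math.sin(angle) if col % 2 == 0 else math.cos(angle))
        table.append(row)
    return table

pe = sinusoidal_positional_encoding(seq_len=4, model_dim=6)
print(pe[0])
print(pe[1])
```

&lt;p>Position 0 always encodes to alternating 0 and 1, and every later position gets a distinct pattern, which gives the otherwise order-agnostic attention layers access to token order.&lt;/p>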
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">TransformerModel&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">input_dim&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model_dim&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_layers&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">output_dim&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">TransformerModel&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">embedding&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Embedding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">input_dim&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model_dim&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">positional_encoding&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Parameter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zeros&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1000&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model_dim&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># learnable positional encoding; assumes a maximum sequence length of 1000&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transformer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Transformer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">model_dim&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">nhead&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">num_heads&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_encoder_layers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">num_layers&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_first&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># inputs are (batch, seq, feature), matching the indexing in forward()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model_dim&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">output_dim&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">src&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">src_seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt_seq_length&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">src&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tgt&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">src&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">embedding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">src&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">positional_encoding&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="p">:&lt;/span>&lt;span class="n">src_seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">:]&lt;/span> &lt;span class="c1"># slice to the actual sequence length&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">tgt&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">embedding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tgt&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">positional_encoding&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="p">:&lt;/span>&lt;span class="n">tgt_seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">:]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transformer_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transformer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">src&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">transformer_output&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">output&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
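&lt;p>As a quick sanity check on the tensor layout used above, the following is a minimal sketch that calls &lt;code>nn.Transformer&lt;/code> directly with batch-first inputs and a causal target mask. All sizes (&lt;code>vocab_size&lt;/code>, &lt;code>model_dim&lt;/code>, &lt;code>dim_feedforward=64&lt;/code>, and so on) are hypothetical values chosen for illustration, and the mask is an addition that the listing above does not use.&lt;/p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only to keep the example small.
vocab_size, model_dim, num_heads, num_layers = 100, 32, 4, 2
batch_size, src_len, tgt_len = 2, 7, 5

embedding = nn.Embedding(vocab_size, model_dim)
transformer = nn.Transformer(
    d_model=model_dim,
    nhead=num_heads,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    dim_feedforward=64,
    batch_first=True,  # tensors are (batch, seq, feature)
)
transformer.eval()  # disable dropout for a deterministic shape check

# Token ids with shape (batch, seq).
src = torch.randint(0, vocab_size, (batch_size, src_len))
tgt = torch.randint(0, vocab_size, (batch_size, tgt_len))

# Causal mask: target position i may only attend to positions at or before i.
tgt_mask = transformer.generate_square_subsequent_mask(tgt_len)

with torch.no_grad():
    out = transformer(embedding(src), embedding(tgt), tgt_mask=tgt_mask)

# The output follows the target sequence: (batch, tgt_len, model_dim).
out_shape = tuple(out.shape)
print(out_shape)
```

&lt;p>The output length follows &lt;code>tgt&lt;/code>, not &lt;code>src&lt;/code>, which is why the final linear layer in the model produces one prediction per target position.&lt;/p>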
&lt;/div>&lt;p>The individual modules are implemented from scratch below:&lt;/p>
&lt;h3 id="注意力机制">Attention mechanism
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;span class="lnt">25
&lt;/span>&lt;span class="lnt">26
&lt;/span>&lt;span class="lnt">27
&lt;/span>&lt;span class="lnt">28
&lt;/span>&lt;span class="lnt">29
&lt;/span>&lt;span class="lnt">30
&lt;/span>&lt;span class="lnt">31
&lt;/span>&lt;span class="lnt">32
&lt;/span>&lt;span class="lnt">33
&lt;/span>&lt;span class="lnt">34
&lt;/span>&lt;span class="lnt">35
&lt;/span>&lt;span class="lnt">36
&lt;/span>&lt;span class="lnt">37
&lt;/span>&lt;span class="lnt">38
&lt;/span>&lt;span class="lnt">39
&lt;/span>&lt;span class="lnt">40
&lt;/span>&lt;span class="lnt">41
&lt;/span>&lt;span class="lnt">42
&lt;/span>&lt;span class="lnt">43
&lt;/span>&lt;span class="lnt">44
&lt;/span>&lt;span class="lnt">45
&lt;/span>&lt;span class="lnt">46
&lt;/span>&lt;span class="lnt">47
&lt;/span>&lt;span class="lnt">48
&lt;/span>&lt;span class="lnt">49
&lt;/span>&lt;span class="lnt">50
&lt;/span>&lt;span class="lnt">51
&lt;/span>&lt;span class="lnt">52
&lt;/span>&lt;span class="lnt">53
&lt;/span>&lt;span class="lnt">54
&lt;/span>&lt;span class="lnt">55
&lt;/span>&lt;span class="lnt">56
&lt;/span>&lt;span class="lnt">57
&lt;/span>&lt;span class="lnt">58
&lt;/span>&lt;span class="lnt">59
&lt;/span>&lt;span class="lnt">60
&lt;/span>&lt;span class="lnt">61
&lt;/span>&lt;span class="lnt">62
&lt;/span>&lt;span class="lnt">63
&lt;/span>&lt;span class="lnt">64
&lt;/span>&lt;span class="lnt">65
&lt;/span>&lt;span class="lnt">66
&lt;/span>&lt;span class="lnt">67
&lt;/span>&lt;span class="lnt">68
&lt;/span>&lt;span class="lnt">69
&lt;/span>&lt;span class="lnt">70
&lt;/span>&lt;span class="lnt">71
&lt;/span>&lt;span class="lnt">72
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MultiHeadAttention&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">MultiHeadAttention&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">assert&lt;/span> &lt;span class="n">d_model&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="n">num_heads&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s2">&amp;#34;d_model must be divisible by num_heads&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">d_model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">d_model&lt;/span> &lt;span class="c1"># model dimension (e.g. 512)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">num_heads&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">num_heads&lt;/span> &lt;span class="c1"># number of attention heads (e.g. 8)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">d_k&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">d_model&lt;/span> &lt;span class="o">//&lt;/span> &lt;span class="n">num_heads&lt;/span> &lt;span class="c1"># dimension per head (e.g. 64)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Linear projection layers (nn.Linear includes a bias by default; pass bias=False to drop it)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">W_q&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># query projection&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">W_k&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># key projection&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">W_v&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># value projection&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">W_o&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># output projection&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">scaled_dot_product_attention&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Q&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">K&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">V&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">mask&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">None&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Compute scaled dot-product attention.
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Input shapes:
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Q: (batch_size, num_heads, seq_length, d_k)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> K, V: same as Q
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Output shape: (batch_size, num_heads, seq_length, d_k)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> &amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Attention scores: dot product of Q and K, scaled by sqrt(d_k)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">attn_scores&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">matmul&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Q&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">K&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transpose&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">-&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">math&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sqrt&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">d_k&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Apply the mask (e.g. a padding mask or a causal mask); positions where mask == 0 are suppressed&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="n">mask&lt;/span> &lt;span class="ow">is&lt;/span> &lt;span class="ow">not&lt;/span> &lt;span class="kc">None&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">attn_scores&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">attn_scores&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">masked_fill&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">mask&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">-&lt;/span>&lt;span class="mf">1e9&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Attention weights (softmax normalization over the key dimension)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">attn_probs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">softmax&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">attn_scores&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dim&lt;/span>&lt;span class="o">=-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Weighted sum of the value vectors&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">matmul&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">attn_probs&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">V&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">output&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">split_heads&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Split the input tensor into multiple heads.
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Input shape: (batch_size, seq_length, d_model)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Output shape: (batch_size, num_heads, seq_length, d_k)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> &amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">num_heads&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">d_k&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transpose&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">combine_heads&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Merge the per-head outputs back into the original shape.
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Input shape: (batch_size, num_heads, seq_length, d_k)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Output shape: (batch_size, seq_length, d_model)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> &amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_k&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transpose&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">contiguous&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Q&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">K&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">V&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">mask&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">None&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Forward pass.
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Input shape: Q/K/V: (batch_size, seq_length, d_model)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Output shape: (batch_size, seq_length, d_model)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> &amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Apply the linear projections, then split into heads&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Q&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">split_heads&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">W_q&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Q&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># (batch, heads, seq_len, d_k)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">K&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">split_heads&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">W_k&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">K&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">V&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">split_heads&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">W_v&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">V&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Compute attention&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">attn_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">scaled_dot_product_attention&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Q&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">K&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">V&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">mask&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Merge the heads and apply the output projection&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">W_o&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">combine_heads&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">attn_output&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">output&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
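&lt;p>The mask convention in &lt;code>scaled_dot_product_attention&lt;/code> (entries equal to 0 are suppressed) can be illustrated with a causal mask. The following is a minimal sketch using a standalone copy of that method; returning the attention weights alongside the output and the tensor sizes are assumptions made here for inspection, not part of the class above.&lt;/p>

```python
import math
import torch


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Standalone copy of the method above that also returns the weights."""
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 get a large negative score,
        # so softmax assigns them (numerically) zero weight.
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights


torch.manual_seed(0)
batch_size, num_heads, seq_length, d_k = 2, 4, 5, 8  # hypothetical sizes
Q = torch.randn(batch_size, num_heads, seq_length, d_k)
K = torch.randn(batch_size, num_heads, seq_length, d_k)
V = torch.randn(batch_size, num_heads, seq_length, d_k)

# Lower-triangular causal mask: query i may attend to keys 0..i.
# Shape (1, 1, seq, seq) broadcasts over the batch and head dimensions.
causal_mask = torch.tril(torch.ones(seq_length, seq_length)).unsqueeze(0).unsqueeze(0)

output, weights = scaled_dot_product_attention(Q, K, V, causal_mask)

# Largest weight placed on a future position, and the row sums of the weights.
max_future_weight = torch.triu(weights, diagonal=1).abs().max().item()
row_sums = weights.sum(dim=-1)
print(tuple(output.shape), max_future_weight)
```

&lt;p>A padding mask follows the same convention: set the entries for padded key positions to 0 and shape the mask so that it broadcasts against the &lt;code>(batch, heads, seq, seq)&lt;/code> score tensor.&lt;/p>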
&lt;/div>&lt;h3 id="位置编码">Positional encoding
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">PositionWiseFeedForward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_ff&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">PositionWiseFeedForward&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_ff&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># first fully connected layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_ff&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># second fully connected layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># activation function&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Feed-forward computation: fc2(ReLU(fc1(x)))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">PositionalEncoding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">max_seq_length&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">PositionalEncoding&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">pe&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zeros&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">max_seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># initialize the positional encoding matrix&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">position&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">arange&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">max_seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">float&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">unsqueeze&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">div_term&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">exp&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">arange&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">float&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="o">-&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">math&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">log&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mf">10000.0&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">pe&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">::&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sin&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">position&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">div_term&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 偶数位置使用正弦函数&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">pe&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">::&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cos&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">position&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">div_term&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 奇数位置使用余弦函数&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">register_buffer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;pe&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">pe&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">unsqueeze&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># 注册为缓冲区&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 将位置编码添加到输入中&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">pe&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="p">:&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)]&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="encoder">Encoder
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">EncoderLayer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_ff&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dropout&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">EncoderLayer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">self_attn&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MultiHeadAttention&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 自注意力机制&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">feed_forward&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">PositionWiseFeedForward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_ff&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 前馈网络&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">LayerNorm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 层归一化&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">LayerNorm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># Dropout&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">mask&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 自注意力机制&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">attn_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">self_attn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">mask&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">attn_output&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># 残差连接和层归一化&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 前馈网络&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">ff_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">feed_forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">ff_output&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># 残差连接和层归一化&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="decoder">Decoder
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">DecoderLayer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_ff&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dropout&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">DecoderLayer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">self_attn&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MultiHeadAttention&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 自注意力机制&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cross_attn&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MultiHeadAttention&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 交叉注意力机制&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">feed_forward&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">PositionWiseFeedForward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_ff&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 前馈网络&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">LayerNorm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 层归一化&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">LayerNorm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm3&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">LayerNorm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># Dropout&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">enc_output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">src_mask&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt_mask&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 自注意力机制&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">attn_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">self_attn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt_mask&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">attn_output&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># 残差连接和层归一化&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 交叉注意力机制&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">attn_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cross_attn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">enc_output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">enc_output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">src_mask&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">attn_output&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># 残差连接和层归一化&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 前馈网络&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">ff_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">feed_forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">norm3&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">ff_output&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># 残差连接和层归一化&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="整体架构">整体架构
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;span class="lnt">25
&lt;/span>&lt;span class="lnt">26
&lt;/span>&lt;span class="lnt">27
&lt;/span>&lt;span class="lnt">28
&lt;/span>&lt;span class="lnt">29
&lt;/span>&lt;span class="lnt">30
&lt;/span>&lt;span class="lnt">31
&lt;/span>&lt;span class="lnt">32
&lt;/span>&lt;span class="lnt">33
&lt;/span>&lt;span class="lnt">34
&lt;/span>&lt;span class="lnt">35
&lt;/span>&lt;span class="lnt">36
&lt;/span>&lt;span class="lnt">37
&lt;/span>&lt;span class="lnt">38
&lt;/span>&lt;span class="lnt">39
&lt;/span>&lt;span class="lnt">40
&lt;/span>&lt;span class="lnt">41
&lt;/span>&lt;span class="lnt">42
&lt;/span>&lt;span class="lnt">43
&lt;/span>&lt;span class="lnt">44
&lt;/span>&lt;span class="lnt">45
&lt;/span>&lt;span class="lnt">46
&lt;/span>&lt;span class="lnt">47
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">Transformer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">src_vocab_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt_vocab_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_layers&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_ff&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">max_seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dropout&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Transformer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">encoder_embedding&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Embedding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">src_vocab_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 编码器词嵌入&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">decoder_embedding&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Embedding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tgt_vocab_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_model&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 解码器词嵌入&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">positional_encoding&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">PositionalEncoding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">max_seq_length&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 位置编码&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 编码器和解码器层&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">encoder_layers&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ModuleList&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="n">EncoderLayer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_ff&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dropout&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="n">_&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_layers&lt;/span>&lt;span class="p">)])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">decoder_layers&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ModuleList&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="n">DecoderLayer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">d_ff&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dropout&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="n">_&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_layers&lt;/span>&lt;span class="p">)])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt_vocab_size&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 最终的全连接层&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># Dropout&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">generate_mask&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">src&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 源掩码：屏蔽填充符（假设填充符索引为0）&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 形状：(batch_size, 1, 1, seq_length)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">src_mask&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">src&lt;/span> &lt;span class="o">!=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">unsqueeze&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">unsqueeze&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 目标掩码：屏蔽填充符和未来信息&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 形状：(batch_size, 1, seq_length, 1)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">tgt_mask&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">tgt&lt;/span> &lt;span class="o">!=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">unsqueeze&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">unsqueeze&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">seq_length&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tgt&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 生成上三角矩阵掩码，防止解码时看到未来信息&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nopeak_mask&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">triu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ones&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">seq_length&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">seq_length&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">diagonal&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">))&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bool&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">tgt_mask&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tgt_mask&lt;/span> &lt;span class="o">&amp;amp;&lt;/span> &lt;span class="n">nopeak_mask&lt;/span> &lt;span class="c1"># 合并填充掩码和未来信息掩码&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">src_mask&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt_mask&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">src&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Generate the source and target masks&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">src_mask&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt_mask&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">generate_mask&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">src&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Encoder&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">src_embedded&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">positional_encoding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">encoder_embedding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">src&lt;/span>&lt;span class="p">)))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">enc_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">src_embedded&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">enc_layer&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">encoder_layers&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">enc_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">enc_layer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">enc_output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">src_mask&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Decoder&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">tgt_embedded&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">positional_encoding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">decoder_embedding&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tgt&lt;/span>&lt;span class="p">)))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">dec_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tgt_embedded&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">dec_layer&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">decoder_layers&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">dec_output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">dec_layer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dec_output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">enc_output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">src_mask&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tgt_mask&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Final output projection&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dec_output&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">output&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="优化算法">Optimization Algorithms
&lt;/h2>&lt;h2 id="计算机视觉">Computer Vision
&lt;/h2></description></item><item><title>Activation Functions</title><link>https://hych0317.github.io/hugoweb_auto/p/%E6%BF%80%E6%B4%BB%E5%87%BD%E6%95%B0/</link><pubDate>Sun, 20 Oct 2024 21:01:22 +0800</pubDate><guid>https://hych0317.github.io/hugoweb_auto/p/%E6%BF%80%E6%B4%BB%E5%87%BD%E6%95%B0/</guid><description>&lt;h2 id="激活函数">Activation Functions
&lt;/h2>&lt;p>Activation functions are introduced to give a neural network nonlinearity. Without one, each layer amounts to a matrix multiplication, so a network with any number of stacked layers still computes a single linear transformation and is no more expressive than a single-layer perceptron.
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/main.png"
loading="lazy"
alt="Activation function diagram"
>&lt;/p>
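&lt;h3 id="分类">Categories
&lt;/h3>&lt;ul>
&lt;li>Saturating activation functions: sigmoid, tanh&amp;hellip;&lt;/li>
&lt;li>Non-saturating activation functions: ReLU, Leaky ReLU, ELU, PReLU, RReLU&amp;hellip;&lt;/li>
&lt;/ul>
&lt;p>A function is right-saturated if its derivative tends to 0 as x tends to positive infinity.
It is left-saturated if its derivative tends to 0 as x tends to negative infinity.&lt;/p>
&lt;blockquote>
&lt;p>A function that is both left-saturated and right-saturated is called saturating; otherwise it is non-saturating.&lt;/p>
&lt;/blockquote>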
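&lt;p>Non-saturating activation functions have two main advantages:&lt;/p>
&lt;ol>
&lt;li>
&lt;p>They mitigate the vanishing-gradient problem that arises in very deep networks.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>They speed up convergence.&lt;/p>
&lt;/li>
&lt;/ol>
&lt;p>Notes:&lt;/p>
&lt;ul>
&lt;li>
&lt;p>Saturation leads to vanishing gradients: saturation is one of the main causes of vanishing gradients. When the output of an activation function approaches its saturation value, its derivative is very small. During backpropagation these small factors are multiplied together layer by layer, so the gradient vanishes as the network gets deeper.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Vanishing and exploding gradients: both arise from the repeated multiplication of gradients in deep networks, but they are opposite extremes. With vanishing gradients the gradient keeps shrinking as it propagates; with exploding gradients it keeps growing.&lt;/p>
&lt;/li>
&lt;/ul>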
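&lt;p>The effect of multiplying small derivatives can be illustrated numerically. The following is a minimal sketch rather than a model of a real network: it multiplies only the local sigmoid derivatives along one path and ignores the weights, and the helper names (sigmoid_grad, chained_gradient) are made up for this illustration.&lt;/p>

```python
import math


def sigmoid(x):
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x):
    """Derivative of the sigmoid: sigma(x) * (1 - sigma(x)), never larger than 0.25."""
    s = sigmoid(x)
    return s * (1.0 - s)


def chained_gradient(pre_activations):
    """Product of the local sigmoid derivatives along one path (weights ignored)."""
    grad = 1.0
    for z in pre_activations:
        grad *= sigmoid_grad(z)
    return grad


# Best case: every pre-activation is 0, where the derivative peaks at 0.25.
best_case_10_layers = chained_gradient([0.0] * 10)
# Saturated case: every pre-activation is 5, deep in the flat part of the curve.
saturated_10_layers = chained_gradient([5.0] * 10)
print(best_case_10_layers, saturated_10_layers)
```

&lt;p>Even in the best case the factor shrinks by 0.25 per layer, and it shrinks much faster once the units saturate.&lt;/p>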
&lt;h3 id="常见激活函数">Common Activation Functions
&lt;/h3>&lt;h4 id="sigmoid函数">Sigmoid
&lt;/h4>&lt;p>$$\sigma(x) = \frac{1}{1+e^{-x}}$$
Its derivative is:
$$\sigma'(x) = \frac{e^{-x}}{(1+e^{-x})^2} = \sigma(x)(1-\sigma(x))$$
Plot:&lt;br>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/sigmoid.png"
loading="lazy"
alt="Plot of the sigmoid function"
>&lt;/p>
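&lt;p>Advantages:&lt;/p>
&lt;ol>
&lt;li>The output lies in (0, 1), so it can be read as the probability of the positive class in binary classification or as a confidence score.&lt;/li>
&lt;li>The function is smooth and differentiable everywhere, which avoids abrupt jumps in the gradient during training.&lt;/li>
&lt;/ol>
&lt;p>Disadvantages:&lt;/p>
&lt;ol>
&lt;li>The output is not zero-centered, so the mean of the activations drifts away from 0.&lt;/li>
&lt;li>It saturates for inputs of large magnitude, which makes it prone to vanishing gradients.&lt;/li>
&lt;li>It requires evaluating an exponential, which is comparatively expensive.&lt;/li>
&lt;/ol>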
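&lt;p>As a small illustration of the formulas in this subsection, the sketch below evaluates the sigmoid in a way that stays accurate for very negative inputs and compares the closed-form derivative with a finite-difference estimate. The function names and the step size h are assumptions chosen for this example.&lt;/p>

```python
import math


def stable_sigmoid(x):
    """Sigmoid that never calls exp() on a large positive argument."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x is negative here, so exp(x) cannot overflow
    return z / (1.0 + z)


def sigmoid_grad(x):
    """Closed form: sigma'(x) = sigma(x) * (1 - sigma(x))."""
    s = stable_sigmoid(x)
    return s * (1.0 - s)


def numeric_grad(f, x, h=1e-6):
    """Central finite difference, used only as a sanity check."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


x0 = 0.7
analytic = sigmoid_grad(x0)
numeric = numeric_grad(stable_sigmoid, x0)
print(analytic, numeric)
```

&lt;p>The two values should agree to several decimal places; stable_sigmoid(-1000.0) returns 0.0 instead of raising an overflow error.&lt;/p>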
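&lt;h4 id="softmax函数归一化指数函数">Softmax (normalized exponential function)
&lt;/h4>&lt;p>$$\mathrm{softmax}(x_i) = \frac{\exp(x_i)}{\sum_{j}\exp(x_j)}$$&lt;/p>
&lt;p>Properties:&lt;/p>
&lt;ol>
&lt;li>Softmax is commonly used as the output-layer activation for &lt;strong>multi-class classification&lt;/strong>.
It maps any real vector of length K to a real vector of length K whose entries lie in [0, 1] and sum to 1, i.e. a probability distribution. Each entry is interpreted as the probability of the corresponding class.&lt;/li>
&lt;/ol>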
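&lt;p>A minimal pure-Python sketch of the softmax formula follows. Subtracting the maximum logit before exponentiating is a common numerical-stability trick that is not part of the formula itself; it leaves the result unchanged. The input values are arbitrary.&lt;/p>

```python
import math


def softmax(logits):
    """Map a list of real numbers to a probability distribution."""
    m = max(logits)                           # shift so the largest exponent is 0
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.1])
print(probs, sum(probs))
```

&lt;p>Because of the shift, large logits such as [1000.0, 1000.0] do not overflow.&lt;/p>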
&lt;p>Example of use in an output layer:&lt;br>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/softmax_exp.png"
loading="lazy"
alt="Example of softmax in an output layer"
>&lt;/p>
&lt;h4 id="tanh函数">Tanh
&lt;/h4>&lt;p>$$\tanh(x) = \frac{\exp(x)-\exp(-x)}{\exp(x)+\exp(-x)}$$
Its derivative is:
$$\tanh'(x) = 1-\tanh^2(x)$$
In effect, tanh is a sigmoid rescaled so that its output lies between -1 and 1:&lt;br>
$$\tanh(x) = 2\sigma(2x)-1$$&lt;/p>
&lt;p>Plot:&lt;br>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/tanh.png"
loading="lazy"
alt="Plot of the tanh function"
>&lt;/p>
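&lt;p>Advantages:&lt;/p>
&lt;ol>
&lt;li>The function is zero-centered, with outputs between -1 and 1.&lt;/li>
&lt;li>The derivative at 0 is 1 (compared with 0.25 for sigmoid), which eases, but does not remove, the saturation problem.&lt;/li>
&lt;/ol>
&lt;p>Disadvantages:&lt;/p>
&lt;ol>
&lt;li>It still saturates for inputs of large magnitude.&lt;/li>
&lt;li>It still requires evaluating exponentials.&lt;/li>
&lt;/ol>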
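&lt;p>The relation between tanh and sigmoid, tanh(x) = 2&amp;sigma;(2x) - 1, can be checked numerically. This is a small sketch using only the standard library; the function names are chosen for this example.&lt;/p>

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def tanh_via_sigmoid(x):
    """tanh written as a shifted and rescaled sigmoid."""
    return 2.0 * sigmoid(2.0 * x) - 1.0


def tanh_grad(x):
    """Derivative of tanh: 1 - tanh(x)^2."""
    return 1.0 - math.tanh(x) ** 2


points = (-2.0, -0.5, 0.0, 0.5, 2.0)
identity_holds = all(
    math.isclose(tanh_via_sigmoid(x), math.tanh(x), abs_tol=1e-12) for x in points
)
print(identity_holds, tanh_grad(0.0))
```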
&lt;h4 id="softsign函数更平滑的tanh函数">Softsign (a smoother alternative to tanh)
&lt;/h4>&lt;p>$$\mathrm{softsign}(x) = \frac{x}{1+|x|}$$
Its derivative is:
$$\mathrm{softsign}'(x) = \frac{1}{(1+|x|)^2}$$&lt;/p>
&lt;p>Plot:&lt;br>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/softsign.png"
loading="lazy"
alt="Plot of the softsign function"
>&lt;/p>
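&lt;p>Advantages:&lt;/p>
&lt;ol>
&lt;li>The curve is flatter and approaches its asymptotes more slowly (polynomially rather than exponentially), so the gradient decays more slowly, which can make learning more efficient.&lt;/li>
&lt;li>It eases the saturation problem further than tanh does.&lt;/li>
&lt;/ol>
&lt;p>Disadvantages:&lt;/p>
&lt;ol>
&lt;li>It is more cumbersome to compute: both the function and its derivative involve an absolute value and a division (although no exponential is needed).&lt;/li>
&lt;/ol>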
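&lt;p>To make the slower decay concrete, the sketch below implements softsign with the derivative 1 / (1 + |x|)^2, checks that derivative against a finite difference, and compares it with the tanh derivative at x = 5. The helper names and the sample points are assumptions for this illustration.&lt;/p>

```python
import math


def softsign(x):
    return x / (1.0 + abs(x))


def softsign_grad(x):
    """Derivative of softsign: 1 / (1 + |x|)^2."""
    return 1.0 / (1.0 + abs(x)) ** 2


def numeric_grad(f, x, h=1e-6):
    """Central finite difference, used only as a sanity check."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


x0 = 1.5
analytic = softsign_grad(x0)
numeric = numeric_grad(softsign, x0)

# Far from the origin, softsign keeps a much larger gradient than tanh.
softsign_far = softsign_grad(5.0)
tanh_far = 1.0 - math.tanh(5.0) ** 2
print(analytic, numeric, softsign_far, tanh_far)
```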
&lt;h4 id="relu函数线性整流函数">ReLU (rectified linear unit)
&lt;/h4>&lt;p>$$\mathrm{ReLU}(x) = \max(0,x)$$&lt;/p>
&lt;p>Plot:&lt;br>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/relu.png"
loading="lazy"
alt="Plot of the ReLU function"
>&lt;/p>
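&lt;p>Advantages:&lt;/p>
&lt;ol>
&lt;li>It is simple and fast to compute.&lt;/li>
&lt;li>For positive inputs there is no gradient saturation.&lt;/li>
&lt;/ol>
&lt;p>Disadvantages:&lt;/p>
&lt;ol>
&lt;li>The output is not zero-centered, so the mean of the activations drifts away from 0.&lt;/li>
&lt;li>Dead ReLU problem: when a neuron's input is negative, its output and its gradient are both 0. If this holds for every training input, the neuron no longer responds to the data and the weights feeding it stop being updated.&lt;/li>
&lt;/ol>
&lt;p>&lt;strong>Note that although Leaky ReLU and ELU both avoid the dead-ReLU problem, practice has not shown them to be consistently better than ReLU.&lt;/strong>&lt;/p>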
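&lt;p>The dead-ReLU problem can be shown with a single neuron. In this sketch the weight, the bias, and the inputs are made-up numbers chosen so that the pre-activation is negative for every input; alpha = 0.01 is the commonly used Leaky ReLU slope.&lt;/p>

```python
def relu(x):
    return max(0.0, x)


def relu_grad(x):
    """Derivative of ReLU (taken as 0 at x = 0)."""
    return 1.0 if x > 0 else 0.0


def leaky_relu(x, alpha=0.01):
    return max(alpha * x, x)


def leaky_relu_grad(x, alpha=0.01):
    """Derivative of Leaky ReLU: 1 on the positive side, alpha otherwise."""
    return 1.0 if x > 0 else alpha


# One neuron computing act(weight * x + bias); the bias is so negative that
# the pre-activation is below zero for every input in the data set.
inputs = [-1.0, 0.5, 2.0]
weight, bias = 1.0, -10.0
pre_acts = [weight * x + bias for x in inputs]

# Gradient of the activation output with respect to the weight, per input.
relu_weight_grads = [relu_grad(z) * x for z, x in zip(pre_acts, inputs)]
leaky_weight_grads = [leaky_relu_grad(z) * x for z, x in zip(pre_acts, inputs)]
print(relu_weight_grads, leaky_weight_grads)
```

&lt;p>With ReLU every gradient is zero, so the weight can never recover; with Leaky ReLU a small gradient still flows.&lt;/p>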
&lt;h4 id="softplus函数relu的平滑版本">Softplus (a smooth version of ReLU)
&lt;/h4>&lt;p>$$\mathrm{softplus}(x) = \ln(1+e^x)$$&lt;/p>
&lt;p>Plot:&lt;br>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/softplus.png"
loading="lazy"
alt="Plot of the softplus function"
>&lt;/p>
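&lt;p>Properties:&lt;/p>
&lt;ol>
&lt;li>Softplus is smooth and differentiable everywhere, so unlike ReLU, whose derivative jumps at 0, its derivative has no discontinuity during backpropagation.&lt;/li>
&lt;li>Its gradient is nonzero on the whole domain, so a unit never receives an exactly zero gradient the way a dead ReLU does; the gradient still tends to 0 as x tends to negative infinity.&lt;/li>
&lt;li>It is more expensive to compute.&lt;/li>
&lt;/ol>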
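&lt;p>The sketch below evaluates softplus in a numerically stable form and its derivative, which is the sigmoid. The rewriting ln(1 + e^x) = max(x, 0) + ln(1 + e^(-|x|)) is a standard identity used here to avoid overflow; the function names are chosen for this example.&lt;/p>

```python
import math


def softplus(x):
    """ln(1 + e^x), written so that exp() never sees a large positive argument."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_grad(x):
    """Derivative of softplus, which equals sigmoid(x)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


at_zero = softplus(0.0)            # equals ln(2)
far_right = softplus(1000.0)       # behaves like x for large x
grad_far_left = softplus_grad(-20.0)  # positive, but very small
print(at_zero, far_right, grad_far_left)
```

&lt;p>The last value shows that the gradient never reaches exactly zero for moderately negative inputs, yet it becomes very small.&lt;/p>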
&lt;h4 id="leaky-relu函数">Leaky ReLU
&lt;/h4>&lt;p>$$\mathrm{LeakyReLU}(x) = \max(\alpha x,x), \quad 0 &amp;lt; \alpha &amp;lt; 1$$&lt;/p>
&lt;p>Plot:&lt;br>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/Leaky_ReLU.png"
loading="lazy"
alt="Plot of the Leaky ReLU function"
>&lt;/p>
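&lt;p>Advantages:&lt;/p>
&lt;ol>
&lt;li>It avoids the dead-ReLU problem.&lt;/li>
&lt;li>Like ReLU, it does not saturate for positive inputs.&lt;/li>
&lt;/ol>
&lt;p>Disadvantages:&lt;/p>
&lt;ol>
&lt;li>$\alpha$ has to be set by hand and is hard to tune; 0.01 is a common choice.&lt;/li>
&lt;li>It is close to linear, which can hurt performance on complex classification tasks.&lt;/li>
&lt;/ol>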
&lt;h4 id="prelu函数">PReLU
&lt;/h4>&lt;p>$$\mathrm{PReLU}(x) = \max(\alpha x,x)$$
Unlike Leaky ReLU, the negative-side slope α of PReLU is learned during training rather than fixed by hand.&lt;/p>
&lt;p>Its other properties are similar to those of Leaky ReLU.&lt;/p>
&lt;h4 id="elu函数">ELU
&lt;/h4>&lt;p>$$\mathrm{ELU}(x) = \left\{ \begin{array}{ll} \alpha(e^x-1) &amp;amp; x&amp;lt;0 \\ x &amp;amp; x\geq0 \end{array} \right. $$&lt;/p>
&lt;p>Plot:&lt;br>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/ELU.png"
loading="lazy"
alt="Plot of the ELU function"
>&lt;/p>
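&lt;p>Advantages:&lt;/p>
&lt;ol>
&lt;li>ELU pushes the mean activation toward zero, which brings the normal gradient closer to the unit natural gradient and speeds up learning.&lt;/li>
&lt;li>ELU saturates to a negative value for strongly negative inputs, which reduces the variation and information propagated forward.&lt;/li>
&lt;li>Outputs can be negative, so the activations are closer to zero-centered; the output range is (-α, +∞).&lt;/li>
&lt;/ol>
&lt;p>Disadvantages:&lt;/p>
&lt;ol>
&lt;li>It requires evaluating an exponential, which is slower.&lt;/li>
&lt;/ol>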
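&lt;p>A minimal sketch of ELU and its derivative follows, assuming the common default alpha = 1.0 (the text above does not fix a value). It also shows that the output is bounded below by -alpha and unbounded above.&lt;/p>

```python
import math


def elu(x, alpha=1.0):
    """ELU: identity for x >= 0, alpha * (e^x - 1) for negative x."""
    return x if x >= 0 else alpha * math.expm1(x)


def elu_grad(x, alpha=1.0):
    """Derivative of ELU: 1 for x >= 0, alpha * e^x for negative x."""
    return 1.0 if x >= 0 else alpha * math.exp(x)


lower_bound_default = elu(-1000.0)           # saturates at -alpha = -1.0
lower_bound_half = elu(-1000.0, alpha=0.5)   # saturates at -0.5
positive_side = elu(2.0)                     # unchanged
print(lower_bound_default, lower_bound_half, positive_side)
```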
&lt;h4 id="selu函数">SELU
&lt;/h4>&lt;p>$$\mathrm{SELU}(x) = \lambda\left\{ \begin{array}{ll} \alpha(e^x-1) &amp;amp; x&amp;lt;0 \\ x &amp;amp; x\geq0 \end{array} \right. $$&lt;/p>
&lt;p>Plot:&lt;br>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/SELU.png"
loading="lazy"
alt="Plot of the SELU function"
>&lt;/p>
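&lt;p>Properties:&lt;/p>
&lt;ol>
&lt;li>SELU makes it possible to construct a mapping g whose properties yield an SNN (self-normalizing neural network).&lt;/li>
&lt;li>SELU normalizes internally by steering the mean and variance of the activations; this internal normalization is faster than external normalization, so the network converges faster.&lt;/li>
&lt;/ol>
&lt;p>Note:
requirements on the activation function of an SNN:&lt;/p>
&lt;ol>
&lt;li>Both negative and positive values, to control the mean;&lt;/li>
&lt;li>Saturation regions (derivative approaching zero), to dampen the variance when it is too large in the lower layer;&lt;/li>
&lt;li>A slope greater than 1, to increase the variance when it is too small in the lower layer;&lt;/li>
&lt;li>A continuous curve.&lt;/li>
&lt;/ol>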
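&lt;p>The self-normalizing behaviour can be checked empirically. The sketch below uses the constants &amp;lambda; &amp;asymp; 1.0507 and &amp;alpha; &amp;asymp; 1.6733 commonly quoted for SELU (an assumption, since the text above does not give them) and feeds standard-normal samples through the function; the sample size and seed are arbitrary.&lt;/p>

```python
import math
import random

# Commonly quoted SELU constants (assumed here, not taken from the text above).
SELU_LAMBDA = 1.0507009873554805
SELU_ALPHA = 1.6732632423543772


def selu(x):
    if x >= 0:
        return SELU_LAMBDA * x
    return SELU_LAMBDA * SELU_ALPHA * math.expm1(x)


def mean_and_variance(values):
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return mean, var


rng = random.Random(0)
samples = [rng.gauss(0.0, 1.0) for _ in range(200_000)]
mean_out, var_out = mean_and_variance([selu(z) for z in samples])
print(mean_out, var_out)
```

&lt;p>For inputs with mean 0 and variance 1, the outputs should again have mean close to 0 and variance close to 1, which is the fixed point an SNN relies on.&lt;/p>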
&lt;h4 id="swish函数">Swish
&lt;/h4>&lt;p>$$\mathrm{swish}(x) = x\sigma(x)$$&lt;/p>
&lt;p>Plot:&lt;br>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/activation_function/swish.png"
loading="lazy"
alt="Plot of the swish function"
>&lt;/p>
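&lt;p>Properties:&lt;/p>
&lt;ol>
&lt;li>It is built on the sigmoid function (the input gated by its own sigmoid) but is smoother, and this smoothness plays an important role in optimization and generalization.&lt;/li>
&lt;li>Being unbounded above helps avoid the slow training that occurs when gradients gradually approach 0 and the unit saturates.&lt;/li>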
&lt;/ol></description></item><item><title>LLM_flow</title><link>https://hych0317.github.io/hugoweb_auto/p/llm_flow/</link><pubDate>Wed, 25 Sep 2024 22:49:25 +0800</pubDate><guid>https://hych0317.github.io/hugoweb_auto/p/llm_flow/</guid><description>&lt;h2 id="前置">Prerequisites
&lt;/h2>&lt;h3 id="服务器使用指南">Server Usage Guide
&lt;/h3>&lt;p>Several ways to select GPUs:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Restrict which GPUs are visible to this process&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">os&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">os&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">environ&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s2">&amp;#34;CUDA_VISIBLE_DEVICES&amp;#34;&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="s2">&amp;#34;0,1&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Comma-separated GPU indices, starting from 0 (use an ASCII comma); must be set before importing torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">device&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;cuda:0&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># selects a single GPU only&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Use GPUs 0, 1 and 2 (if device_ids is omitted or None, all visible GPUs are used)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">DataParallel&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">device_ids&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="pytorch安装">Installing PyTorch
&lt;/h3>&lt;p>Installing with pip tends to be more reliable than with conda; for the command, see (&lt;a class="link" href="https://pytorch.org/get-started/locally/" target="_blank" rel="noopener"
>https://pytorch.org/get-started/locally/&lt;/a>).
For details such as the CUDA version, see the 土堆 (Tudui) tutorial.
(Temporary proxy for pip: --proxy=http://)&lt;/p>
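&lt;h3 id="下载预训练模型">Downloading Pretrained Models
&lt;/h3>&lt;p>&lt;strong>Note: use git lfs clone so that the model weight files are downloaded.&lt;/strong>&lt;/p>
&lt;ul>
&lt;li>
&lt;p>git clone&lt;br>
git clone url --depth=1: download only the most recent commit&lt;br>
(-c http.proxy=&amp;quot;http://127.0.0.1:7890&amp;quot; # temporary proxy)&lt;br>
While LFS downloads a large file, the file size on disk keeps growing but the progress percentage and speed appear stuck; the progress only updates once a whole large file has finished, so just wait.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Downloading from Hugging Face:&lt;/p>
&lt;p>1. (Recommended)&lt;/p>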
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;span class="lnt">8
&lt;/span>&lt;span class="lnt">9
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-bash" data-lang="bash">&lt;span class="line">&lt;span class="cl">sudo apt install aria2
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">wget https://hf-mirror.com/hfd/hfd.sh
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">chmod +x hfd.sh
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">export&lt;/span> &lt;span class="nv">HF_ENDPOINT&lt;/span>&lt;span class="o">=&lt;/span>https://hf-mirror.com
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Download a model (if the download is slow, cancel it and rerun the command to resume)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">./hfd.sh &amp;lt;model_name&amp;gt; --tool aria2c -x &lt;span class="m">4&lt;/span> &lt;span class="o">[&lt;/span>optional&lt;span class="o">]&lt;/span>--hf_username &amp;lt;username&amp;gt; --hf_token &amp;lt;apikey&amp;gt;
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Proxy: --all-proxy=http://&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Download a dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">./hfd.sh &amp;lt;dataset_name&amp;gt; --dataset --tool aria2c -x &lt;span class="m">4&lt;/span> &lt;span class="o">[&lt;/span>optional&lt;span class="o">]&lt;/span>--hf_username &amp;lt;username&amp;gt; --hf_token &amp;lt;apikey&amp;gt;
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;ol start="2">
&lt;li>&lt;/li>
&lt;/ol>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;span class="lnt">8
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-bash" data-lang="bash">&lt;span class="line">&lt;span class="cl">pip install -U huggingface_hub
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Point the endpoint environment variable at the mirror (consider adding this to ~/.bashrc)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">export&lt;/span> &lt;span class="nv">HF_ENDPOINT&lt;/span>&lt;span class="o">=&lt;/span>https://hf-mirror.com
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Download a model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">huggingface-cli download --resume-download &amp;lt;model_name&amp;gt; --local-dir &amp;lt;local_dir_name&amp;gt;
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Download a dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">huggingface-cli download --repo-type dataset --resume-download &amp;lt;dataset_name&amp;gt; --local-dir &amp;lt;local_dir_name&amp;gt;
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">## Add --local-dir-use-symlinks False to disable symlinks, so the download directory contains the actual files&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;ol start="3">
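&lt;li>Choose Clone repository (expand the three-dot menu) and use the git clone command it provides to download to the local machine.&lt;/li>
&lt;/ol>
&lt;/li>
&lt;li>
&lt;p>Downloading from ModelScope (魔搭社区):&lt;br>
Follow the same approaches as above.&lt;/p>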
&lt;/li>
&lt;/ul>
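&lt;h2 id="推理inference">Inference
&lt;/h2>&lt;h3 id="概述">Overview
&lt;/h3>&lt;p>Inference means running a model on input data to obtain its predictions.&lt;br>
The process can be split into three steps:&lt;/p>
&lt;ol>
&lt;li>Load the model: load both the model architecture and its weights.&lt;/li>
&lt;li>Preprocess the input: for example, tokenization and padding.&lt;/li>
&lt;li>Run the model: perform a forward pass to obtain the output.&lt;/li>
&lt;/ol>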
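&lt;h2 id="微调finetune">Fine-tuning
&lt;/h2>&lt;h3 id="介绍">Introduction
&lt;/h3>&lt;h4 id="序言">Preface
&lt;/h4>&lt;p>The first level of using large models well is mastering prompt engineering;
the second level is fine-tuning.&lt;/p>
&lt;p>Prompt engineering tends to make prompts very long.
Fine-tuning uses your own data to improve the model's performance on a specific task and to reduce hallucinations.&lt;/p>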
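&lt;h4 id="技术路线">Technical Approaches
&lt;/h4>&lt;p>In terms of how many parameters are trained, fine-tuning of large models follows two approaches:&lt;/p>
&lt;ul>
&lt;li>Full fine-tuning, FFT (Full Fine Tuning): all parameters are trained.&lt;/li>
&lt;li>Parameter-efficient fine-tuning, PEFT (Parameter-Efficient Fine Tuning): only a subset of parameters is trained, as in LoRA.&lt;/li>
&lt;/ul>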
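&lt;p>To give a feel for the difference in trainable parameters, the sketch below compares full fine-tuning of one weight matrix with a LoRA update of rank r, where the update is the product of an (out, r) matrix and an (r, in) matrix. The 4096 x 4096 layer size and rank 16 are illustrative assumptions, not values from a specific model.&lt;/p>

```python
def full_finetune_params(d_out, d_in):
    """Trainable weights when the whole d_out x d_in matrix is updated."""
    return d_out * d_in


def lora_params(d_out, d_in, rank):
    """Trainable weights of a LoRA update B @ A, with B: (d_out, rank), A: (rank, d_in)."""
    return rank * (d_out + d_in)


d_out, d_in, rank = 4096, 4096, 16   # illustrative sizes
full = full_finetune_params(d_out, d_in)
lora = lora_params(d_out, d_in, rank)
ratio = lora / full
print(full, lora, ratio)
```

&lt;p>For these sizes the LoRA update trains less than 1% of the weights of the layer it adapts.&lt;/p>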
&lt;p>In terms of training method:&lt;/p>
&lt;ul>
&lt;li>Supervised fine-tuning, SFT (Supervised Fine Tuning): uses human-labeled data and standard supervised learning to fine-tune the model.&lt;/li>
&lt;li>RLHF (Reinforcement Learning from Human Feedback): brings human feedback into fine-tuning through reinforcement learning, so that the model's outputs better match human expectations.&lt;/li>
&lt;li>RLAIF (Reinforcement Learning from AI Feedback): works much like RLHF, except that the feedback comes from an AI.&lt;/li>
&lt;/ul>
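&lt;h3 id="微调流程">Fine-tuning Workflow
&lt;/h3>&lt;p>&lt;a class="link" href="https://www.bilibili.com/video/BV1dr421w7J5" target="_blank" rel="noopener"
>LoRA reference workflow&lt;/a>&lt;/p>
&lt;h4 id="数据集">Dataset
&lt;/h4>&lt;p>The instruction field usually describes the task type or gives the instruction.&lt;br>
The input field holds the text the model has to process.&lt;br>
The output field holds the correct answer or expected output for that input.&lt;/p>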
&lt;blockquote>
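&lt;p>Commonly used Chinese fine-tuning datasets may include:&lt;br>
Chinese question-answering datasets (e.g. CMRC 2018, DRCD), for training QA systems.&lt;br>
Chinese sentiment-analysis datasets (e.g. ChnSentiCorp, Fudan News), for training sentiment classifiers.&lt;br>
Chinese text-similarity datasets (e.g. LCQMC, BQ Corpus), for sentence-pair matching and similarity tasks.&lt;br>
Chinese summarization datasets (e.g. LCSTS, NLPCC), for training summarization models.&lt;br>
Chinese dialogue datasets (e.g. LCCC, ECDT), for training chatbots or dialogue systems.&lt;/p>
&lt;/blockquote>
&lt;h4 id="训练过程">Training
&lt;/h4>&lt;h4 id="评估与迭代">Evaluation and Iteration
&lt;/h4>&lt;h4 id="lora训练示例">LoRA Training Example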
&lt;/h4>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;span class="lnt">25
&lt;/span>&lt;span class="lnt">26
&lt;/span>&lt;span class="lnt">27
&lt;/span>&lt;span class="lnt">28
&lt;/span>&lt;span class="lnt">29
&lt;/span>&lt;span class="lnt">30
&lt;/span>&lt;span class="lnt">31
&lt;/span>&lt;span class="lnt">32
&lt;/span>&lt;span class="lnt">33
&lt;/span>&lt;span class="lnt">34
&lt;/span>&lt;span class="lnt">35
&lt;/span>&lt;span class="lnt">36
&lt;/span>&lt;span class="lnt">37
&lt;/span>&lt;span class="lnt">38
&lt;/span>&lt;span class="lnt">39
&lt;/span>&lt;span class="lnt">40
&lt;/span>&lt;span class="lnt">41
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span> &lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">nn&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn.functional&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">F&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">math&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">LoRALinear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">in_features&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">out_features&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">merge&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">rank&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">lora_alpha&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dropout&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.5&lt;/span>&lt;span class="p">):&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">LoRALinear&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">in_features&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">in_features&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">out_features&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">out_features&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">merge&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">merge&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rank&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">rank&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout_rate&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">dropout&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">lora_alpha&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">lora_alpha&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">linear&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">in_features&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">out_features&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="n">rank&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">lora_b&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Parameter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zeros&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">out_features&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">rank&lt;/span>&lt;span class="p">))&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">lora_a&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Parameter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zeros&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">rank&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">in_features&lt;/span>&lt;span class="p">))&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">scale&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">lora_alpha&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rank&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">linear&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">requires_grad&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="kc">False&lt;/span>&lt;span class="err">​&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">if&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout_rate&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout_rate&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">else&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Identity&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">initial_weights&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">initial_weights&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">init&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">kaiming_uniform_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">lora_a&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">a&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">math&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sqrt&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">init&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zeros_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">lora_b&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">if&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rank&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="mi">0&lt;/span> &lt;span class="ow">and&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">merge&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">linear&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">lora_b&lt;/span> &lt;span class="o">@&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">lora_a&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">scale&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">linear&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bias&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="k">return&lt;/span> &lt;span class="n">output&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">else&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="k">return&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
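&lt;/div>&lt;h2 id="量化quantization">Quantization
&lt;/h2>&lt;p>Quantization stores weights (and optionally activations) at lower numerical precision, for example INT8 instead of FP32. This shrinks the model and usually speeds up inference.&lt;/p>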
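&lt;p>As a rough illustration of the idea, the sketch below applies symmetric per-tensor INT8 quantization to a short list of weights in plain Python. The helper names (quantize_int8, dequantize) and the sample weights are hypothetical, and real toolchains typically add calibration and per-channel scales; this only shows how a single scale factor trades precision for size.&lt;/p>

```python
def quantize_int8(values):
    """Symmetric per-tensor quantization of floats to integers in [-127, 127]."""
    max_abs = max(abs(v) for v in values)
    # Guard against an all-zero tensor, where any scale works.
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Round to the nearest integer step and clip to the INT8 range.
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Map the stored integers back to approximate float values."""
    return [q * scale for q in quantized]


# Hypothetical weights, chosen only for illustration.
weights = [0.5, -1.27, 0.03, 1.0]
q_weights, scale = quantize_int8(weights)
restored = dequantize(q_weights, scale)

# Largest absolute difference between original and restored weights.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(q_weights, scale, max_error)
```

&lt;p>Each restored weight differs from the original by at most about half the scale, which is the rounding error that the lower precision introduces.&lt;/p>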
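&lt;h2 id="蒸馏distillation">Distillation
&lt;/h2>&lt;p>Distillation trains a small (student) model to learn from the knowledge of a large (teacher) model, which improves the small model's performance.&lt;/p>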
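&lt;p>One common way to make the student learn from the teacher is a soft-target loss: both sets of logits are softened with a temperature and the student is penalized by the KL divergence from the teacher distribution. The sketch below is a minimal plain-Python version under that assumption; the function names, the temperature of 2.0, and the sample logits are hypothetical, and a practical setup would usually combine this term with the ordinary label loss.&lt;/p>

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; higher temperatures give softer distributions."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on softened distributions, scaled by temperature squared."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student))
    return kl * temperature ** 2


# Hypothetical logits for a three-class problem.
teacher = [4.0, 1.0, 0.2]
close_student = [3.5, 1.2, 0.1]
far_student = [0.1, 3.0, 2.0]

loss_close = distillation_loss(close_student, teacher)
loss_far = distillation_loss(far_student, teacher)
print(loss_close, loss_far)
```

&lt;p>A student whose logits resemble the teacher's gets a loss near zero, while one that ranks the classes differently gets a larger loss, so minimizing this value pulls the student toward the teacher's output distribution.&lt;/p>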
&lt;h2 id="部署deployment">Deployment
&lt;/h2>&lt;h3 id="vllm部署">Deploying with vLLM
&lt;/h3>&lt;h2 id="评估evaluation">Evaluation
&lt;/h2></description></item><item><title>wandb</title><link>https://hych0317.github.io/hugoweb_auto/p/wandb/</link><pubDate>Thu, 19 Sep 2024 14:36:30 +0800</pubDate><guid>https://hych0317.github.io/hugoweb_auto/p/wandb/</guid><description>&lt;blockquote>
&lt;p>(&lt;a class="link" href="https://docs.wandb.ai/quickstart/" target="_blank" rel="noopener"
>https://docs.wandb.ai/quickstart/&lt;/a>)&lt;/p>
&lt;/blockquote>
&lt;h2 id="安装库">Installing the library
&lt;/h2>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Shell commands (run in a terminal, not in Python):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">#   pip install wandb&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">#   wandb login    (create an account at wandb.ai first)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">os&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">wandb&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Initialize a run inside the model training code&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">init&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">project&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;my-project&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Declare hyperparameters&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dropout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mf">0.2&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">hidden_layer_size&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">128&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Log metrics&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">my_train_loop&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">for&lt;/span> &lt;span class="n">epoch&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span> &lt;span class="c1"># change as appropriate :)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">log&lt;/span>&lt;span class="p">({&lt;/span>&lt;span class="s1">&amp;#39;epoch&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">epoch&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;loss&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">loss&lt;/span>&lt;span class="p">})&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Save files&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># by default, this will save to a new subfolder for files associated&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># with your run, created in wandb.run.dir (which is ./wandb by default)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;mymodel.h5&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># With Keras, pass the full path to model.save() (model is assumed to be a Keras model defined elsewhere)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">os&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">path&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">join&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">run&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dir&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s2">&amp;#34;mymodel.h5&amp;#34;&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Once wandb is set up, model outputs, logs, and the files you choose to save are synced to the cloud.&lt;/p>
&lt;h2 id="pytorch应用wandb">Using wandb with PyTorch
&lt;/h2>&lt;p>A very simple neural network serves as an example of how to use wandb:&lt;/p>
&lt;h3 id="首先导入必要的库">Import the required libraries
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">__future__&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">print_function&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">argparse&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">random&lt;/span> &lt;span class="c1"># to set the python random seed&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">numpy&lt;/span> &lt;span class="c1"># to set the numpy random seed&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn.functional&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">F&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.optim&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">optim&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torchvision&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">datasets&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">transforms&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.utils.data&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">DataLoader&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Ignore excessive warnings&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">logging&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Raise the root logger threshold so that only errors are reported&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">logging&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">getLogger&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">setLevel&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">logging&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ERROR&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># WandB – Import the wandb library&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">wandb&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="使用前的准备">Preparation
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Log in to your wandb account so that all metrics can be logged.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># wandb.login() is the Python equivalent of the notebook shell escape `!wandb login`.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">login&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Define the convolutional neural network:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">Net&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Net&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># In our constructor, we define our neural network architecture that we&amp;#39;ll use in the forward pass.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Conv2d() adds a convolution layer that generates 2 dimensional feature maps&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># to learn different aspects of our image.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">6&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">6&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Linear(x,y) creates dense, fully connected layers with x inputs and y outputs.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Linear layers output the dot product of our inputs and weights, plus a bias.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">16&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">120&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">120&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">84&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc3&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">84&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Here we feed the feature maps from the convolutional layers into a max_pool2d layer.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># The max_pool2d layer reduces the size of the image representation our convolutional layers learnt,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># and in doing so it reduces the number of parameters and computations the network needs to perform.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Finally we apply the relu activation function which gives us max(0, max_pool2d_output)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max_pool2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max_pool2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Reshapes x into size (-1, 16 * 5 * 5); this fits 32x32 inputs (an assumption: 32 to 28 to 14 to 10 to 5)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># so we can feed the convolution layer outputs into our fully connected layer.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># We apply the relu activation function to the output of the first two fully connected layers.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc3&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Finally we apply log_softmax, which returns the log-probability of each class (0-9);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># exponentiating these values gives probabilities that add to 1.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">log_softmax&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dim&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># 定义训练函数&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">train&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">config&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">device&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">train_loader&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">epoch&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># switch model to training mode. This is necessary for layers like dropout, batchNorm etc.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># which behave differently in training and evaluation mode.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># we loop over the data iterator, and feed the inputs to the network and adjust the weights.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">batch_id&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">target&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">train_loader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="n">batch_id&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="mi">20&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">break&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Load the input features and labels from the training dataset onto the device.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">data&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">target&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">target&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Reset the gradients to 0 for all learnable weight parameters&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zero_grad&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Forward pass: feed the training images through the network to predict&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># which class (0-9 in this case) each image belongs to.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Define our loss function, and compute the loss&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">nll_loss&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">target&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Backward pass: compute the gradients of the loss with respect to the model&amp;#39;s parameters&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># update the neural network weights&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">step&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Define the test function&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># wandb.log records metrics (accuracy, loss and epoch) so that network performance can be inspected at any time&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">test&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">args&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">device&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">test_loader&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">classes&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># switch model to evaluation mode.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># This is necessary for layers like dropout, batchNorm etc. which behave differently in training and evaluation mode&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">test_loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">correct&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">example_images&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">[]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">with&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">no_grad&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">data&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">target&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">test_loader&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Load the input features and labels from the test dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">data&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">target&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">target&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Make predictions: feed the test images through the network to predict&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># which class (0-9 in this case) each image belongs to.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Compute the loss, summed over the batch, and accumulate it&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">test_loss&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">nll_loss&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">target&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">reduction&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;sum&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Get the index of the max log-probability&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">pred&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">output&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">keepdim&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)[&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">correct&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="n">pred&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eq&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">target&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view_as&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">pred&lt;/span>&lt;span class="p">))&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Log images in your test dataset automatically,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># along with predicted and true labels by passing pytorch tensors with image data into wandb.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">example_images&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">append&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Image&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">data&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">caption&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;Pred:&lt;/span>&lt;span class="si">{}&lt;/span>&lt;span class="s2"> Truth:&lt;/span>&lt;span class="si">{}&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">format&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">classes&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">pred&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">()],&lt;/span> &lt;span class="n">classes&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">target&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]])))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># wandb.log(a_dict) logs the keys and values of the dictionary passed in and associates the values with a step.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># You can log anything by passing it to wandb.log(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># including histograms, custom matplotlib objects, images, video, text, tables, html, point clouds and other 3D objects.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Here we use it to log test accuracy, loss and some test images (along with their true and predicted labels).&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">log&lt;/span>&lt;span class="p">({&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;Examples&amp;#34;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">example_images&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;Test Accuracy&amp;#34;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="mf">100.&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">correct&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">test_loader&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dataset&lt;/span>&lt;span class="p">),&lt;/span>
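&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;Test Loss&amp;#34;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">test_loss&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">test_loader&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dataset&lt;/span>&lt;span class="p">)&lt;/span>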
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">})&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Initialize a wandb run and set the hyperparameters:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Initialize a new run&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">init&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">project&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;pytorch-intro&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">watch_called&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="kc">False&lt;/span> &lt;span class="c1"># Re-run the model without restarting the runtime, unnecessary after our next release&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># config is a variable that holds and saves hyper parameters and inputs&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">config&lt;/span> &lt;span class="c1"># Initialize config&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">batch_size&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">4&lt;/span> &lt;span class="c1"># input batch size for training (default:64)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">test_batch_size&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">10&lt;/span> &lt;span class="c1"># input batch size for testing(default:1000)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">epochs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">50&lt;/span> &lt;span class="c1"># number of epochs to train(default:10)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">lr&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mf">0.1&lt;/span> &lt;span class="c1"># learning rate(default:0.01)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">momentum&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mf">0.1&lt;/span> &lt;span class="c1"># SGD momentum(default:0.5)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">no_cuda&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="kc">False&lt;/span> &lt;span class="c1"># set to True to disable CUDA training&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">seed&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">42&lt;/span> &lt;span class="c1"># random seed(default:42)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">log_interval&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">10&lt;/span> &lt;span class="c1"># how many batches to wait before logging training status&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="主函数">Main function
&lt;/h3>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;span class="lnt">25
&lt;/span>&lt;span class="lnt">26
&lt;/span>&lt;span class="lnt">27
&lt;/span>&lt;span class="lnt">28
&lt;/span>&lt;span class="lnt">29
&lt;/span>&lt;span class="lnt">30
&lt;/span>&lt;span class="lnt">31
&lt;/span>&lt;span class="lnt">32
&lt;/span>&lt;span class="lnt">33
&lt;/span>&lt;span class="lnt">34
&lt;/span>&lt;span class="lnt">35
&lt;/span>&lt;span class="lnt">36
&lt;/span>&lt;span class="lnt">37
&lt;/span>&lt;span class="lnt">38
&lt;/span>&lt;span class="lnt">39
&lt;/span>&lt;span class="lnt">40
&lt;/span>&lt;span class="lnt">41
&lt;/span>&lt;span class="lnt">42
&lt;/span>&lt;span class="lnt">43
&lt;/span>&lt;span class="lnt">44
&lt;/span>&lt;span class="lnt">45
&lt;/span>&lt;span class="lnt">46
&lt;/span>&lt;span class="lnt">47
&lt;/span>&lt;span class="lnt">48
&lt;/span>&lt;span class="lnt">49
&lt;/span>&lt;span class="lnt">50
&lt;/span>&lt;span class="lnt">51
&lt;/span>&lt;span class="lnt">52
&lt;/span>&lt;span class="lnt">53
&lt;/span>&lt;span class="lnt">54
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">main&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">use_cuda&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="ow">not&lt;/span> &lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">no_cuda&lt;/span> &lt;span class="ow">and&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">is_available&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">device&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;cuda:0&amp;#34;&lt;/span> &lt;span class="k">if&lt;/span> &lt;span class="n">use_cuda&lt;/span> &lt;span class="k">else&lt;/span> &lt;span class="s2">&amp;#34;cpu&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">kwargs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="s1">&amp;#39;num_workers&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;pin_memory&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="kc">True&lt;/span>&lt;span class="p">}&lt;/span> &lt;span class="k">if&lt;/span> &lt;span class="n">use_cuda&lt;/span> &lt;span class="k">else&lt;/span> &lt;span class="p">{}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Set random seeds and deterministic pytorch for reproducibility&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># random.seed(config.seed) # python random seed&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">manual_seed&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">seed&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># pytorch random seed&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># numpy.random.seed(config.seed) # numpy random seed&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backends&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cudnn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">deterministic&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="kc">True&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Load the dataset: We&amp;#39;re training our CNN on CIFAR10.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># First we define the transformations to apply to our images.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transform&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Compose&lt;/span>&lt;span class="p">([&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Normalize&lt;/span>&lt;span class="p">((&lt;/span>&lt;span class="mf">0.5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.5&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="mf">0.5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.5&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Now we load our training and test datasets and apply the transformations defined above&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">train_loader&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">datasets&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CIFAR10&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">root&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;./data&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">train&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">download&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">transform&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">),&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">**&lt;/span>&lt;span class="n">kwargs&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">test_loader&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">datasets&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CIFAR10&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">root&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;./data&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">train&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">False&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">download&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">transform&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">),&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">test_batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">False&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">**&lt;/span>&lt;span class="n">kwargs&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">classes&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;plane&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;car&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;bird&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;cat&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;deer&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;dog&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;frog&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;horse&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;ship&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;truck&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Initialize our model, recursively go over all modules and convert their parameters&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># and buffers to CUDA tensors (if device is set to cuda)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Net&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">optimizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">SGD&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">lr&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">momentum&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">momentum&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># wandb.watch() automatically fetches all layer dimensions, gradients, model parameters&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># and logs them automatically to your dashboard.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Using log=&amp;#34;all&amp;#34; also logs histograms of parameter values in addition to gradients.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">watch&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">log&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;all&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">epoch&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">config&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">epochs&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">train&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">config&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">device&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">train_loader&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">epoch&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">test&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">config&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">device&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">test_loader&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">classes&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Save the model checkpoint. This automatically saves a file to the cloud&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="s1">&amp;#39;model.h5&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">wandb&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;model.h5&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">if&lt;/span> &lt;span class="vm">__name__&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="s1">&amp;#39;__main__&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">main&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div></description></item><item><title>LLM Trends and Personal Career Reflections -- Li Mu</title><link>https://hych0317.github.io/hugoweb_auto/p/llm%E8%B6%8B%E5%8A%BF%E5%8F%8A%E4%B8%AA%E4%BA%BA%E7%94%9F%E6%B6%AF%E5%88%86%E4%BA%AB--%E6%9D%8E%E6%B2%90/</link><pubDate>Sat, 24 Aug 2024 21:01:10 +0800</pubDate><guid>https://hych0317.github.io/hugoweb_auto/p/llm%E8%B6%8B%E5%8A%BF%E5%8F%8A%E4%B8%AA%E4%BA%BA%E7%94%9F%E6%B6%AF%E5%88%86%E4%BA%AB--%E6%9D%8E%E6%B2%90/</guid><description>&lt;h2 id="第一部分llm">Part 1: LLM
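&lt;/h2>&lt;p>Following the previous wave of deep learning, the wave now rising is language models.&lt;/p>
&lt;p>A language model can be broken into three parts: compute, data, and algorithms (the furnace, the ingredients, and the alchemy).&lt;/p>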
&lt;h3 id="算力">算力
&lt;/h3>&lt;h4 id="算力的发展趋势">算力的发展趋势
&lt;/h4>&lt;p>长期来看，算力的成本将持续降低，大模型训练成本也会不断下降，所以大模型本身也不是一个能保值的东西，价值会随着时间降低，也某种意义上受摩尔定律的影响(训练会两倍两倍地变便宜，今天训练一个模型，一年之后它的价值会减半)。&lt;/p>
&lt;p>&lt;strong>大模型不是特别有性价比的东西。你要想清楚，从长期来看，你的模型能带来什么价值，让你能够保值。&lt;/strong>&lt;/p>
&lt;p>100B 到 500B 会是未来主流的一个大势。你可以做更大，但是它很多时候是用 MoE 做的，它的有效大小（每次激活的大小）可能也就是 500B 的样子。&lt;/p>
&lt;p>PS：对于模型开发者来说，要想在模型的参数量、预训练上想取得突破是不明智的，头部企业的突破速度不是小团队可以想象的。更应该关注模型的后训练、应用部署。&lt;/p>
&lt;h4 id="算力的瓶颈">算力的瓶颈
&lt;/h4>&lt;p>通信带宽：分布式训练要求多卡间的通信延迟足够小。由于供电、散热的问题，芯片间需要保持一定距离，导致了通信延迟。&lt;br>
即使是使用光纤以光速传输，一米距离传输带来的几纳秒延迟对性能影响也很大。&lt;/p>
&lt;p>内存：大模型训练需要大量的内存，目前最大的单个内存芯片可以达到192GB。&lt;br>
内存占面积大，一个芯片划一块给算力，划一块给内存之后就放不下什么东西了。这会严重限制模型大小。&lt;br>
(在这一块，虽然英伟达是领先者，但其实英伟达是不如 AMD 的，甚至不如 Google 的 TPU)&lt;/p>
&lt;h3 id="数据">数据
&lt;/h3>&lt;p>当前数据&lt;strong>质量&lt;/strong>的提升比数量提升更重要。&lt;/p>
&lt;p>数据决定了模型的上限，算法决定了模型的下限。&lt;br>
就目前来说，我们离 AGI 还很远， AGI 能够做自主的学习，我们目前的模型就是填鸭式状态。&lt;br>
Claude团队花了很大的力气来做数据，在数据上用了很多年。所以，想让模型在某一个方面做得特别好，需要先把相关数据准备好。大家还是用了 70-80% 时间在数据上。&lt;/p>
&lt;p>数据质量的提升还来自于：&lt;/p>
&lt;ol>
&lt;li>标注数据：语言模型的训练数据需要大量的标注数据，比如训练集、验证集、测试集。&lt;/li>
&lt;li>领域数据：语言模型的训练数据需要包含领域相关的文本，比如新闻、科技、医疗等。&lt;/li>
&lt;li>多样性：语言模型的训练数据需要包含各种语言、方言、口音等多样性的文本。&lt;/li>
&lt;/ol>
&lt;h3 id="趋势">趋势
&lt;/h3>&lt;p>End-to-end和多模态是当前大模型的趋势。&lt;/p>
&lt;p>目前语言模型已经达到了较高的水平，大约在 80 到 85 分之间。音频模型在可接受的水平，处于能用阶段，大约在 70-80 分之间。但在视频生成方面，尤其是生成具有特定功能的视频尚显不足，整体水平大约在 50 分左右。&lt;/p>
&lt;h4 id="技术层面">技术层面
&lt;/h4>&lt;h5 id="多模态技术">多模态技术
&lt;/h5>&lt;p>多模态技术的发展趋势在于整合不同类型的模态信息。&lt;br>
由于文本是信息密度最高的，也是最容易获得的。&lt;br>
一是可以借助强大的文本模型进行泛化。二是可以通过文本来定制和控制其他模态的输出，比如用简单的文本指令控制图片、视频和声音的生成。&lt;/p>
&lt;h5 id="预训练与后训练">预训练与后训练
&lt;/h5>&lt;p>预训练是用大量的文本数据训练一个通用语言模型，后训练是用这个通用语言模型来训练特定任务的模型。&lt;/p>
&lt;p>&lt;strong>预训练是工程问题，后训练才是技术问题&lt;/strong>&lt;br>
在预训练方面，现在已经变成一个因为大而导致很多工程问题的困难，这其实还是算法上探索不够，得清楚如何改进算法。&lt;br>
对于后训练，&lt;strong>高质量的数据和改进的算法&lt;/strong>能够极大地提升模型效果。高质量的数据一定是结构化的，并且与应用场景高度相关，以保证数据的多样性和实用性。&lt;/p>
&lt;h5 id="垂直模型也需要通用维度">垂直模型也需要通用维度
&lt;/h5>&lt;p>为什么要做垂直模型呢？因为通用模型的问题还是一个指数问题，你要实现的任务，通用模型不一定能完成。&lt;br>
通用模型是通用维度，需要各个方面都有提升，如果刚好满足你的要求，需要指数级的数据，并且模型会变得很大。&lt;/p>
&lt;p>&lt;strong>但是&lt;/strong>就算是一个很垂直领域的模型，它的通用能力也是不能差的。比如说你要在某一个学科里面拿第一，你别的科目也不能差到哪里去。&lt;/p>
&lt;h5 id="模型评估">模型评估
&lt;/h5>&lt;p>自然语言很难评价其正确性、逻辑性和风格。通常我们不想让人来评估，因为比较昂贵，但使用模型评估会带来偏差。&lt;br>
有一个好的评估可以解决 50% 的问题。因为一旦评估解决了，那你就能够进行优化。第二评估解决了，表示你拥有了一些数据。&lt;/p>
&lt;h4 id="应用层面">应用层面
&lt;/h4>&lt;h5 id="人机交互">人机交互
&lt;/h5>&lt;p>人机交互的方式可能会发生改变。以往人机交互都是通过键鼠和屏幕完成的，未来的语音控制系统将能够处理更加复杂和具体的任务。&lt;br>
手机的 killer APP 是什么吗？短视频。语言模型的killer APP是什么？这个还是未知。&lt;/p>
&lt;h5 id="对人类的替代">对人类的替代
&lt;/h5>&lt;p>数据越多的领域，就越能被自动化。&lt;br>
当前大模型在简单的文科任务上已经能很好地代替人类。因为文科任务是最能简单快速采集大量数据的。在简单理科任务和复杂文科任务上能力正在突破。&lt;br>
而当前想要替代蓝领，还非常遥远。工厂需要投放大量传感器，做好数字化基础设施建设，数据收集和整理方案成熟起来，才有大模型落地的希望。而这一切当前看来还很难，但一旦实现就会是重大变革。&lt;/p>
&lt;h2 id="第二部分职业规划建议">第二部分：职业规划建议
&lt;/h2>&lt;h3 id="几种身份的区别">几种身份的区别
&lt;/h3>&lt;p>&lt;strong>目标和动机的差异&lt;/strong>
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/LLM_trend/diff.png"
loading="lazy"
alt="差别"
>
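At a big company the goal is promotion and a raise; for a PhD the goal is to graduate; for a startup the goal is to cash out and exit.&lt;/p>
&lt;p>&lt;strong>Pros and cons&lt;/strong>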
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/LLM_trend/dgr.png"
loading="lazy"
alt="打工人"
>
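You will not have nightmares at night, but you gradually become a cog in the machine.&lt;br>
The upside is that you can learn all kinds of professional knowledge in a relatively simple environment, with a relatively stable income and free time. The downside is that you stay stuck in the mindset of an employee or a professional manager.&lt;/p>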
&lt;p>&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/LLM_trend/phd.png"
loading="lazy"
alt="PhD"
>
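The upside is that you can spend several years focused on exploring a single field.&lt;br>
The downside is that few labs get to take part in large-scale projects, and you need strong self-motivation. You have to truly love research, or you will not be able to keep going; you will wonder what the point of this research is and why you are writing this paper.&lt;br>
In fact, you can look at it this way: I am writing this paper to practice writing, so that when a stronger, bigger result comes along, writing will not hold me back. You need a larger goal and a real love for the work.&lt;br>
PS: There is no way I could do a PhD.&lt;/p>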
&lt;p>&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/LLM_trend/startup.png"
loading="lazy"
alt="创业"
>&lt;/p>
&lt;h3 id="驱动力的来源">驱动力的来源
&lt;/h3>&lt;p>&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/LLM_trend/motivation.png"
loading="lazy"
alt="动机"
>
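The more basic the desire, the better: fame, wealth, and power are all basic desires. The fear is the kind that can make you depressed, and the kind that makes you feel life and death.&lt;/p>
&lt;p>You need to turn desire and fear into a positive motivation. The motivation must be sound and consistent with your values, because avoidance and indulgence neither satisfy desire nor relieve fear. The only way to overcome them is to turn them into a positive motivation that is of value to society.&lt;/p>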
&lt;p>&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/LLM_trend/motidiff.png"
loading="lazy"
alt="动机差异"
>
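Once you have a motivation, you have to ask what problem you want to solve; the problem may be your motivation itself.&lt;/p>
&lt;p>For example: why do language models work? Nobody knows, and that is a question of real academic value. Can language models incubate new applications? That is a question of commercial value. Failing that, you can think about how to ship language models in a particular product.&lt;/p>
&lt;p>&lt;strong>A method for self-improvement&lt;/strong>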
&lt;img src="https://hych0317.github.io/hugoweb_auto/hugoweb_auto/post/deep-learning/LLM_trend/improve.png"
loading="lazy"
alt="方法"
>
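Why was the goal not met?&lt;/p>
&lt;p>It may be laziness, in which case you have to face it head on. How can I make myself more diligent? Find a study partner, spend every day in the library, hold each other accountable, and so on.&lt;br>
It may also be that you are simply not smart enough at it, which has two solutions. One is to change direction and move to a field you are good at; the other, if the problem cannot be avoided, is to put in twice the time others do.&lt;br>
Whether the cause is laziness or lack of ability, you have to be hard on yourself; in the end it comes down to how hard on yourself you can be.&lt;/p>
&lt;p>Build a habit: set an alarm and spend 30 minutes every Monday evening reviewing your week. Do a review every quarter as well: reread the weekly notes you wrote, check whether this quarter's goals were met, and decide what to do next quarter.&lt;/p>
&lt;p>Choices matter more than effort, but making a choice presupposes knowing what your goal is.&lt;/p></description></item></channel></rss>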