Inference with C# BERT NLP Deep Learning and ONNX Runtime
效果
测试一
Context: Bob is walking through the woods collecting blueberries and strawberries to make a pie.
Question: What is his name?
测试二
Context: Bob is walking through the woods collecting blueberries and strawberries to make a pie.
Question: What will he bring home?
测试三
Context: Bob is walking through the woods collecting blueberries and strawberries to make a pie.
Question: Where is Bob?
模型信息
Inputs
-------------------------
name:unique_ids_raw_output___9:0
tensor:Int64[-1]
name:segment_ids:0
tensor:Int64[-1, 256]
name:input_mask:0
tensor:Int64[-1, 256]
name:input_ids:0
tensor:Int64[-1, 256]
---------------------------------------------------------------
Outputs
-------------------------
name:unstack:1
tensor:Float[-1, 256]
name:unstack:0
tensor:Float[-1, 256]
name:unique_ids:0
tensor:Int64[-1]
---------------------------------------------------------------
项目
代码
using BERTTokenizers;
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Windows.Forms;

namespace Inference_with_C__BERT_NLP_Deep_Learning_and_ONNX_Runtime
{
    /// <summary>
    /// Model inputs for one request. Form1 wraps InputIds, InputMask and SegmentIds
    /// as tensors of shape [1, length] and UniqueIds as a tensor of shape [length].
    /// </summary>
    public struct BertInput
    {
        public long[] InputIds { get; set; }
        public long[] InputMask { get; set; }
        public long[] SegmentIds { get; set; }
        public long[] UniqueIds { get; set; }
    }

    /// <summary>
    /// Question-answering form: runs the BERT-SQuAD ONNX model ("bertsquad-10.onnx")
    /// on the question/context typed by the user and shows the predicted answer span.
    /// </summary>
    public partial class Form1 : Form
    {
        /// <summary>
        /// Fixed sequence length of the model's input_ids/input_mask/segment_ids
        /// inputs (declared as Int64[-1, 256]).
        /// </summary>
        private const int MaxSequenceLength = 256;

        public Form1()
        {
            InitializeComponent();

            // InferenceSession and RunOptions wrap native ONNX Runtime handles;
            // release them when the window is closed.
            FormClosed += (sender, e) =>
            {
                session?.Dispose();
                runOptions?.Dispose();
            };
        }

        RunOptions runOptions;
        InferenceSession session;
        BertUncasedLargeTokenizer tokenizer;

        // Times only the session.Run call (see RunInference).
        Stopwatch stopWatch = new Stopwatch();

        private void Form1_Load(object sender, EventArgs e)
        {
            string modelPath = "bertsquad-10.onnx";
            runOptions = new RunOptions();
            session = new InferenceSession(modelPath);
            tokenizer = new BertUncasedLargeTokenizer();
        }

        // Maximum answer length in tokens, counted as end - start + 1.
        int MaxAnswerLength = 30;

        // Number of top start/end logits that are paired, and of scored spans kept.
        int bestN = 20;

        /// <summary>
        /// Tokenizes the question and context, runs the model and displays the most
        /// probable answer span with its softmax probability and the inference time.
        /// </summary>
        private void button1_Click(object sender, EventArgs e)
        {
            txt_answer.Text = "";
            Application.DoEvents();

            string question = txt_question.Text.Trim();
            string context = txt_context.Text.Trim();

            // Get the sentence tokens.
            var tokens = tokenizer.Tokenize(question, context);

            // The model has a fixed number of positions; longer inputs cannot be padded
            // (Enumerable.Repeat would throw on a negative count).
            if (tokens.Count > MaxSequenceLength)
            {
                txt_answer.Text = $"Input too long: {tokens.Count} tokens (the model accepts at most {MaxSequenceLength}).";
                return;
            }

            // Encode the sentence and pass in the count of the tokens in the sentence.
            var encoded = tokenizer.Encode(tokens.Count, question, context);

            // Zero-pad every input up to the fixed sequence length.
            var padding = Enumerable.Repeat(0L, MaxSequenceLength - tokens.Count).ToList();
            var bertInput = new BertInput()
            {
                InputIds = encoded.Select(t => t.InputIds).Concat(padding).ToArray(),
                InputMask = encoded.Select(t => t.AttentionMask).Concat(padding).ToArray(),
                SegmentIds = encoded.Select(t => t.TokenTypeIds).Concat(padding).ToArray(),
                UniqueIds = new long[] { 0 }
            };

            var (startLogits, endLogits) = RunInference(bertInput);

            // The first [SEP] closes the question; the last [SEP] closes the context.
            var contextStart = tokens.FindIndex(o => o.Token == "[SEP]");
            var contextEnd = tokens.FindLastIndex(o => o.Token == "[SEP]");
            var candidates = FindCandidateSpans(startLogits, endLogits, contextStart, contextEnd);

            string timing = $"\r\n推理耗时:{stopWatch.ElapsedMilliseconds}毫秒";
            if (candidates.Count == 0)
            {
                // Without a valid span there is nothing to show; FirstOrDefault would
                // otherwise yield the default span (0, 0), i.e. "[CLS]".
                string noAnswer = "answer:(no answer found)" + timing;
                txt_answer.Text = noAnswer;
                Console.WriteLine(noAnswer);
                return;
            }

            // Normalise the candidate scores to probabilities and keep the most likely span.
            var (item, probability) = candidates
                .Softmax(o => o.Score)
                .OrderByDescending(o => o.Probability)
                .First();

            var predictedTokens = tokens
                .Skip(item.StartLogit)
                .Take(item.EndLogit + 1 - item.StartLogit)
                .Select(o => tokenizer.IdToToken((int)o.VocabularyIndex))
                .ToList();

            // Print the result.
            string answer = "answer:" + String.Join(" ", StitchSentenceBackTogether(predictedTokens))
                + "\r\nprobability:" + probability
                + timing;
            txt_answer.Text = answer;
            Console.WriteLine(answer);
        }

        /// <summary>
        /// Runs the session on one padded input (batch size 1) and returns copies of the
        /// start and end logits. All input tensors and the output collection are disposed
        /// before returning; the duration of session.Run is recorded in stopWatch.
        /// </summary>
        private (float[] StartLogits, float[] EndLogits) RunInference(BertInput bertInput)
        {
            var inputs = new Dictionary<string, OrtValue>();
            try
            {
                // Create input tensors over the input data.
                inputs.Add("unique_ids_raw_output___9:0", OrtValue.CreateTensorValueFromMemory(bertInput.UniqueIds,
                    new long[] { bertInput.UniqueIds.Length }));
                inputs.Add("segment_ids:0", OrtValue.CreateTensorValueFromMemory(bertInput.SegmentIds,
                    new long[] { 1, bertInput.SegmentIds.Length }));
                inputs.Add("input_mask:0", OrtValue.CreateTensorValueFromMemory(bertInput.InputMask,
                    new long[] { 1, bertInput.InputMask.Length }));
                inputs.Add("input_ids:0", OrtValue.CreateTensorValueFromMemory(bertInput.InputIds,
                    new long[] { 1, bertInput.InputIds.Length }));

                stopWatch.Restart();
                // Request the outputs by name so the start/end order does not depend on
                // session.OutputNames: unstack:0 = start logits, unstack:1 = end logits.
                using (var output = session.Run(runOptions, inputs, new[] { "unstack:0", "unstack:1" }))
                {
                    stopWatch.Stop();
                    // Copy out of native memory before the OrtValues are released.
                    return (output[0].GetTensorDataAsSpan<float>().ToArray(),
                            output[1].GetTensorDataAsSpan<float>().ToArray());
                }
            }
            finally
            {
                // Each OrtValue pins its managed array; disposing releases the pin.
                foreach (var value in inputs.Values)
                {
                    value.Dispose();
                }
            }
        }

        /// <summary>
        /// Pairs the bestN highest start logits with the bestN highest end logits and
        /// returns the bestN best-scoring valid spans (score = start + end logit),
        /// highest score first.
        /// </summary>
        /// <param name="startLogits">Start logit for every input position.</param>
        /// <param name="endLogits">End logit for every input position.</param>
        /// <param name="contextStart">Index of the [SEP] that closes the question; spans start after it.</param>
        /// <param name="contextEnd">Index of the [SEP] that closes the context; spans end before it.</param>
        private List<(int StartLogit, int EndLogit, float Score)> FindCandidateSpans(
            float[] startLogits, float[] endLogits, int contextStart, int contextEnd)
        {
            // Materialised once: the end list is enumerated for every start candidate.
            var bestStartLogits = TopLogits(startLogits, bestN);
            var bestEndLogits = TopLogits(endLogits, bestN);

            return bestStartLogits
                .SelectMany(start => bestEndLogits.Select(end =>
                    (StartLogit: start.Index, EndLogit: end.Index, Score: start.Logit + end.Logit)))
                // A valid answer lies inside the context, is not reversed and
                // is at most MaxAnswerLength tokens long.
                .Where(span => span.StartLogit > contextStart
                    && span.EndLogit < contextEnd
                    && span.EndLogit >= span.StartLogit
                    && span.EndLogit - span.StartLogit + 1 <= MaxAnswerLength)
                .OrderByDescending(span => span.Score)
                .Take(bestN)
                .ToList();
        }

        /// <summary>Returns the <paramref name="count"/> largest logits with their positions, largest first.</summary>
        private static List<(float Logit, int Index)> TopLogits(float[] logits, int count)
        {
            return logits
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(count)
                .ToList();
        }

        /// <summary>
        /// Merges WordPiece continuation tokens ("##xyz") into the preceding token,
        /// e.g. ["blue", "##berries"] becomes ["blueberries"]. The input list is not modified.
        /// A continuation token with no predecessor (span starting mid-word) is kept
        /// with its "##" prefix removed.
        /// </summary>
        private List<string> StitchSentenceBackTogether(List<string> tokens)
        {
            var words = new List<string>(tokens.Count);
            foreach (var token in tokens)
            {
                if (!token.StartsWith("##", StringComparison.Ordinal))
                {
                    words.Add(token);
                    continue;
                }

                var piece = token.Substring(2);
                if (words.Count > 0)
                {
                    words[words.Count - 1] += piece;
                }
                else
                {
                    words.Add(piece);
                }
            }
            return words;
        }
    }
}
using BERTTokenizers;
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Windows.Forms;
namespace Inference_with_C__BERT_NLP_Deep_Learning_and_ONNX_Runtime
{
/// <summary>
/// Model inputs for one request. Form1 wraps InputIds, InputMask and SegmentIds
/// as tensors of shape [1, length] and UniqueIds as a tensor of shape [length].
/// </summary>
public struct BertInput
{
    private long[] _inputIds;
    private long[] _inputMask;
    private long[] _segmentIds;
    private long[] _uniqueIds;

    /// <summary>Vocabulary ids of the tokens (sent as "input_ids:0").</summary>
    public long[] InputIds
    {
        get => _inputIds;
        set => _inputIds = value;
    }

    /// <summary>Attention mask values (sent as "input_mask:0").</summary>
    public long[] InputMask
    {
        get => _inputMask;
        set => _inputMask = value;
    }

    /// <summary>Token type / segment ids (sent as "segment_ids:0").</summary>
    public long[] SegmentIds
    {
        get => _segmentIds;
        set => _segmentIds = value;
    }

    /// <summary>Request ids (sent as "unique_ids_raw_output___9:0").</summary>
    public long[] UniqueIds
    {
        get => _uniqueIds;
        set => _uniqueIds = value;
    }
}
/// <summary>
/// Question-answering form: runs the BERT-SQuAD ONNX model ("bertsquad-10.onnx")
/// on the question/context typed by the user and shows the predicted answer span.
/// </summary>
public partial class Form1 : Form
{
    /// <summary>
    /// Fixed sequence length of the model's input_ids/input_mask/segment_ids
    /// inputs (declared as Int64[-1, 256]).
    /// </summary>
    private const int MaxSequenceLength = 256;

    public Form1()
    {
        InitializeComponent();

        // InferenceSession and RunOptions wrap native ONNX Runtime handles;
        // release them when the window is closed.
        FormClosed += (sender, e) =>
        {
            session?.Dispose();
            runOptions?.Dispose();
        };
    }

    RunOptions runOptions;
    InferenceSession session;
    BertUncasedLargeTokenizer tokenizer;

    // Times only the session.Run call (see RunInference).
    Stopwatch stopWatch = new Stopwatch();

    private void Form1_Load(object sender, EventArgs e)
    {
        string modelPath = "bertsquad-10.onnx";
        runOptions = new RunOptions();
        session = new InferenceSession(modelPath);
        tokenizer = new BertUncasedLargeTokenizer();
    }

    // Maximum answer length in tokens, counted as end - start + 1.
    int MaxAnswerLength = 30;

    // Number of top start/end logits that are paired, and of scored spans kept.
    int bestN = 20;

    /// <summary>
    /// Tokenizes the question and context, runs the model and displays the most
    /// probable answer span with its softmax probability and the inference time.
    /// </summary>
    private void button1_Click(object sender, EventArgs e)
    {
        txt_answer.Text = "";
        Application.DoEvents();

        string question = txt_question.Text.Trim();
        string context = txt_context.Text.Trim();

        // Get the sentence tokens.
        var tokens = tokenizer.Tokenize(question, context);

        // The model has a fixed number of positions; longer inputs cannot be padded
        // (Enumerable.Repeat would throw on a negative count).
        if (tokens.Count > MaxSequenceLength)
        {
            txt_answer.Text = $"Input too long: {tokens.Count} tokens (the model accepts at most {MaxSequenceLength}).";
            return;
        }

        // Encode the sentence and pass in the count of the tokens in the sentence.
        var encoded = tokenizer.Encode(tokens.Count, question, context);

        // Zero-pad every input up to the fixed sequence length.
        var padding = Enumerable.Repeat(0L, MaxSequenceLength - tokens.Count).ToList();
        var bertInput = new BertInput()
        {
            InputIds = encoded.Select(t => t.InputIds).Concat(padding).ToArray(),
            InputMask = encoded.Select(t => t.AttentionMask).Concat(padding).ToArray(),
            SegmentIds = encoded.Select(t => t.TokenTypeIds).Concat(padding).ToArray(),
            UniqueIds = new long[] { 0 }
        };

        var (startLogits, endLogits) = RunInference(bertInput);

        // The first [SEP] closes the question; the last [SEP] closes the context.
        var contextStart = tokens.FindIndex(o => o.Token == "[SEP]");
        var contextEnd = tokens.FindLastIndex(o => o.Token == "[SEP]");
        var candidates = FindCandidateSpans(startLogits, endLogits, contextStart, contextEnd);

        string timing = $"\r\n推理耗时:{stopWatch.ElapsedMilliseconds}毫秒";
        if (candidates.Count == 0)
        {
            // Without a valid span there is nothing to show; FirstOrDefault would
            // otherwise yield the default span (0, 0), i.e. "[CLS]".
            string noAnswer = "answer:(no answer found)" + timing;
            txt_answer.Text = noAnswer;
            Console.WriteLine(noAnswer);
            return;
        }

        // Normalise the candidate scores to probabilities and keep the most likely span.
        var (item, probability) = candidates
            .Softmax(o => o.Score)
            .OrderByDescending(o => o.Probability)
            .First();

        var predictedTokens = tokens
            .Skip(item.StartLogit)
            .Take(item.EndLogit + 1 - item.StartLogit)
            .Select(o => tokenizer.IdToToken((int)o.VocabularyIndex))
            .ToList();

        // Print the result.
        string answer = "answer:" + String.Join(" ", StitchSentenceBackTogether(predictedTokens))
            + "\r\nprobability:" + probability
            + timing;
        txt_answer.Text = answer;
        Console.WriteLine(answer);
    }

    /// <summary>
    /// Runs the session on one padded input (batch size 1) and returns copies of the
    /// start and end logits. All input tensors and the output collection are disposed
    /// before returning; the duration of session.Run is recorded in stopWatch.
    /// </summary>
    private (float[] StartLogits, float[] EndLogits) RunInference(BertInput bertInput)
    {
        var inputs = new Dictionary<string, OrtValue>();
        try
        {
            // Create input tensors over the input data.
            inputs.Add("unique_ids_raw_output___9:0", OrtValue.CreateTensorValueFromMemory(bertInput.UniqueIds,
                new long[] { bertInput.UniqueIds.Length }));
            inputs.Add("segment_ids:0", OrtValue.CreateTensorValueFromMemory(bertInput.SegmentIds,
                new long[] { 1, bertInput.SegmentIds.Length }));
            inputs.Add("input_mask:0", OrtValue.CreateTensorValueFromMemory(bertInput.InputMask,
                new long[] { 1, bertInput.InputMask.Length }));
            inputs.Add("input_ids:0", OrtValue.CreateTensorValueFromMemory(bertInput.InputIds,
                new long[] { 1, bertInput.InputIds.Length }));

            stopWatch.Restart();
            // Request the outputs by name so the start/end order does not depend on
            // session.OutputNames: unstack:0 = start logits, unstack:1 = end logits.
            using (var output = session.Run(runOptions, inputs, new[] { "unstack:0", "unstack:1" }))
            {
                stopWatch.Stop();
                // Copy out of native memory before the OrtValues are released.
                return (output[0].GetTensorDataAsSpan<float>().ToArray(),
                        output[1].GetTensorDataAsSpan<float>().ToArray());
            }
        }
        finally
        {
            // Each OrtValue pins its managed array; disposing releases the pin.
            foreach (var value in inputs.Values)
            {
                value.Dispose();
            }
        }
    }

    /// <summary>
    /// Pairs the bestN highest start logits with the bestN highest end logits and
    /// returns the bestN best-scoring valid spans (score = start + end logit),
    /// highest score first.
    /// </summary>
    /// <param name="startLogits">Start logit for every input position.</param>
    /// <param name="endLogits">End logit for every input position.</param>
    /// <param name="contextStart">Index of the [SEP] that closes the question; spans start after it.</param>
    /// <param name="contextEnd">Index of the [SEP] that closes the context; spans end before it.</param>
    private List<(int StartLogit, int EndLogit, float Score)> FindCandidateSpans(
        float[] startLogits, float[] endLogits, int contextStart, int contextEnd)
    {
        // Materialised once: the end list is enumerated for every start candidate.
        var bestStartLogits = TopLogits(startLogits, bestN);
        var bestEndLogits = TopLogits(endLogits, bestN);

        return bestStartLogits
            .SelectMany(start => bestEndLogits.Select(end =>
                (StartLogit: start.Index, EndLogit: end.Index, Score: start.Logit + end.Logit)))
            // A valid answer lies inside the context, is not reversed and
            // is at most MaxAnswerLength tokens long.
            .Where(span => span.StartLogit > contextStart
                && span.EndLogit < contextEnd
                && span.EndLogit >= span.StartLogit
                && span.EndLogit - span.StartLogit + 1 <= MaxAnswerLength)
            .OrderByDescending(span => span.Score)
            .Take(bestN)
            .ToList();
    }

    /// <summary>Returns the <paramref name="count"/> largest logits with their positions, largest first.</summary>
    private static List<(float Logit, int Index)> TopLogits(float[] logits, int count)
    {
        return logits
            .Select((logit, index) => (Logit: logit, Index: index))
            .OrderByDescending(o => o.Logit)
            .Take(count)
            .ToList();
    }

    /// <summary>
    /// Merges WordPiece continuation tokens ("##xyz") into the preceding token,
    /// e.g. ["blue", "##berries"] becomes ["blueberries"]. The input list is not modified.
    /// A continuation token with no predecessor (span starting mid-word) is kept
    /// with its "##" prefix removed.
    /// </summary>
    private List<string> StitchSentenceBackTogether(List<string> tokens)
    {
        var words = new List<string>(tokens.Count);
        foreach (var token in tokens)
        {
            if (!token.StartsWith("##", StringComparison.Ordinal))
            {
                words.Add(token);
                continue;
            }

            var piece = token.Substring(2);
            if (words.Count > 0)
            {
                words[words.Count - 1] += piece;
            }
            else
            {
                words.Add(piece);
            }
        }
        return words;
    }
}
}