本日はUnityの小ネタ枠です。
Unity.InferenceEngineパッケージを使ってSLMをUnity上で実行する方法です。
SLMとは
SLM(Small Language Model)は小規模な言語モデルを指します。
LLM(Large Language Model)と比較してサイズが小さくエッジデバイスで動くよう最適化されています。
ただしLLMと比較すると知識量や推論性能は低くなります。
モデルを取得する
今回は以下のRinna Japanese GPT2モデルを利用しました。
ONNX版があるため、Unity.InferenceEngineでそのまま利用できます。
saldra/rinna-japanese-gpt2-xsmall-onnx · Hugging Face

[Files and versions]を開き、onnxフォルダのモデルとtokenizer.jsonをダウンロードします。

Unityにモデルを取り込む際、Inspectorでbatch_size = 1、sequence_length = 64〜128に設定して最適化します。
batch_sizeは入力する文章の数、sequence_lengthは一度に扱えるトークン列(文章)の長さに当たります。
デフォルトだと、batch_size = -1(動的)、sequence_length = -1(動的)に設定されており、推論コンパイラが最適化できません。

Unity.InferenceEngineパッケージを使ってSLMをUnity上で実行する
Unity.InferenceEngineでSLMの言語推論を利用する場合、文字列をテンソルとして受け渡す必要があります。
以下のようなRinnaモデルに合わせて言語推論を行うサンプルスクリプトを作成しました。
・RinnaModelService.cs
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using Cysharp.Threading.Tasks;
using Unity.InferenceEngine;
using Unity.InferenceEngine.Tokenization;
using Unity.InferenceEngine.Tokenization.Decoders;
using Unity.InferenceEngine.Tokenization.Mappers;
using Unity.InferenceEngine.Tokenization.PreTokenizers;
using UnityEngine;
using UnityEngine.Networking;
using Newtonsoft.Json.Linq;

namespace Chatbot.SmallLanguageModel.Application
{
    /// <summary>
    /// Runs a Rinna GPT-2 style ONNX decoder model through Unity.InferenceEngine and
    /// exposes a simple text-in / text-out chat interface. Tokenization is loaded from
    /// a Hugging Face tokenizer.json placed in StreamingAssets.
    /// </summary>
    public class RinnaModelService : ISmallLanguageModelService, IDisposable
    {
        private const int MaxSequenceLength = 128;
        private const int MaxGeneratedTokens = 64;
        private static readonly int SamplingTopK = 40;
        private static readonly float SamplingTemperature = 0.8f;
        private static readonly float RepetitionPenalty = 1.1f;
        private const string TokenizerFileName = "tokenizer.json";
        private const string TokenizerFolderName = "InferenceEngine";

        private ModelAsset modelAsset; // Points at decoder_model.onnx
        private Model _runtimeModel;
        private Worker _worker;
        private Tensor<int> _inputTensor;
        private Tensor<int> _attentionTensor;
        private Tensor[] _inputSlots;
        private int[] _inputBuffer;
        private int[] _attentionBuffer;
        private int _inputIdsIndex = -1;
        private int _attentionMaskIndex = -1;
        // Serializes inference requests; the worker and tensors are not re-entrant.
        private readonly SemaphoreSlim _inferenceLock = new SemaphoreSlim(1, 1);
        private AsyncLazy<ITokenizerBridge> _tokenizerLoader;
        private ITokenizerBridge _tokenizer;
        private int _padTokenId;
        private int _eosTokenId = -1;
        private int? _bosTokenId;
        private readonly System.Random _random = new System.Random();

        /// <summary>
        /// Prepares the worker, reusable tensors and the lazy tokenizer loader.
        /// Must be called before <see cref="ProcessInputAsync"/>.
        /// </summary>
        /// <param name="modelAsset">Imported ONNX decoder model asset.</param>
        /// <exception cref="ArgumentNullException">When <paramref name="modelAsset"/> is null.</exception>
        public void Initialize(ModelAsset modelAsset)
        {
            this.modelAsset = modelAsset ?? throw new ArgumentNullException(nameof(modelAsset));
            _runtimeModel = ModelLoader.Load(modelAsset);
            _worker = new Worker(_runtimeModel, Unity.InferenceEngine.DeviceType.GPU);
            _inputBuffer = new int[MaxSequenceLength];
            _attentionBuffer = new int[MaxSequenceLength];
            _inputTensor = new Tensor<int>(new TensorShape(1, MaxSequenceLength));
            _attentionTensor = new Tensor<int>(new TensorShape(1, MaxSequenceLength));
            _inputSlots = new Tensor[_runtimeModel.inputs.Count];
            _inputIdsIndex = ResolveInputIndex("input_ids");
            _attentionMaskIndex = ResolveInputIndex("attention_mask");
            _tokenizerLoader = new AsyncLazy<ITokenizerBridge>(LoadTokenizerAsync);
        }

        /// <summary>
        /// Releases the GPU worker, the reusable input tensors and the inference lock.
        /// Safe to call multiple times; safe to call before <see cref="Initialize"/>.
        /// </summary>
        public void Dispose()
        {
            _worker?.Dispose();
            _worker = null;
            _inputTensor?.Dispose();
            _inputTensor = null;
            _attentionTensor?.Dispose();
            _attentionTensor = null;
            _inferenceLock.Dispose();
        }

        /// <summary>
        /// Tokenizes the user input, runs greedy token-by-token generation and returns
        /// the first sentence of the decoded response. Returns a fixed apology message
        /// (in Japanese) when anything fails.
        /// </summary>
        /// <param name="input">Raw user utterance.</param>
        /// <exception cref="InvalidOperationException">When called before <see cref="Initialize"/>.</exception>
        public async UniTask<string> ProcessInputAsync(string input)
        {
            if (_worker == null)
            {
                throw new InvalidOperationException("InferenceEngineService is not initialized.");
            }
            if (string.IsNullOrWhiteSpace(input))
            {
                return string.Empty;
            }

            // Example prompt for the rinna japanese-gpt2 model family.
            // NOTE(review): the empty "" literal below was presumably meant to be a
            // newline separator between the instruction and the dialogue turns —
            // kept byte-identical to preserve behavior; TODO confirm intent.
            var inputMessage =
                "以下はあなたとわたしの会話です。" +
                "わたしは丁寧な日本語で、分かりやすく短めに答えます。" +
                "" +
                "あなた: " + input +
                "わたし: ";

            await _inferenceLock.WaitAsync();
            try
            {
                var tokenizer = await _tokenizerLoader;
                var sanitizedInput = inputMessage.Trim();
                // Tokenization and decoding are CPU-bound; keep them off the main thread.
                var promptTokens = await UniTask.RunOnThreadPool(() => PrepareSingleTurnPrompt(sanitizedInput, tokenizer));
                var generatedTokens = await GenerateResponseTokensAsync(promptTokens);
                var response = await UniTask.RunOnThreadPool(() => DecodeResponse(generatedTokens, tokenizer));
                Debug.Log($"[InferenceEngineService] Generated response: {response}");

                // Keep the reply simple: split on spaces / sentence punctuation / speaker
                // markers and return only the first fragment.
                var delimiters = new[] { " ", "。", "?", "!", "...", "わたし:", "あなた:" };
                var firstSentence = response.Split(delimiters, StringSplitOptions.RemoveEmptyEntries).FirstOrDefault() ?? string.Empty;

                // If the first fragment is very short (two characters or fewer), append
                // the next fragment so the answer is not trivially empty.
                if (firstSentence.Length <= 2)
                {
                    var parts = response.Split(delimiters, StringSplitOptions.RemoveEmptyEntries);
                    if (parts.Length > 1)
                    {
                        firstSentence += delimiters[0] + parts[1];
                    }
                }
                Debug.Log($"[InferenceEngineService] First sentence extracted: {firstSentence}");
                return firstSentence;
            }
            catch (Exception ex)
            {
                Debug.LogError($"[InferenceEngineService] Failed to process input: {ex.Message}\n{ex}");
                return "申し訳ありません。回答に失敗しました。";
            }
            finally
            {
                _inferenceLock.Release();
            }
        }

        // Finds the slot index of a named model input (case-insensitive).
        private int ResolveInputIndex(string inputName)
        {
            for (var i = 0; i < _runtimeModel.inputs.Count; i++)
            {
                var input = _runtimeModel.inputs[i];
                if (string.Equals(input.name, inputName, StringComparison.OrdinalIgnoreCase))
                {
                    return i;
                }
            }
            throw new InvalidOperationException($"Model input '{inputName}' was not found.");
        }

        // Loads tokenizer.json (UnityWebRequest on Android, File IO elsewhere),
        // builds the tokenizer bridge and caches special-token metadata.
        private async UniTask<ITokenizerBridge> LoadTokenizerAsync()
        {
            string tokenizerJson;
#if UNITY_ANDROID && !UNITY_EDITOR
            tokenizerJson = await LoadTokenizerJsonFromStreamingAssetsAsync();
#else
            var path = GetTokenizerPath();
            if (!File.Exists(path))
            {
                throw new FileNotFoundException($"Tokenizer configuration file not found at {path}.");
            }
            tokenizerJson = await UniTask.RunOnThreadPool(() => File.ReadAllText(path, System.Text.Encoding.UTF8));
#endif
            var tokenizerResult = BuildTokenizer(tokenizerJson);
            ApplyTokenizerMetadata(tokenizerResult.Metadata);
            _tokenizer = tokenizerResult.Tokenizer;
            return tokenizerResult.Tokenizer;
        }

#if UNITY_ANDROID && !UNITY_EDITOR
        // On Android StreamingAssets lives inside the APK, so it must be read
        // through UnityWebRequest rather than System.IO.
        private async UniTask<string> LoadTokenizerJsonFromStreamingAssetsAsync()
        {
            var basePath = UnityEngine.Application.streamingAssetsPath;
            var candidates = new List<string>
            {
                $"{basePath}/{TokenizerFolderName}/{TokenizerFileName}",
                $"{basePath}/{TokenizerFileName}"
            };
            foreach (var candidate in candidates)
            {
                using (var request = UnityWebRequest.Get(candidate))
                {
                    try
                    {
                        await request.SendWebRequest().ToUniTask();
                    }
                    catch (Exception ex)
                    {
                        Debug.LogWarning($"[InferenceEngineService] Failed to load tokenizer from {candidate}: {ex.Message}");
                        continue;
                    }
#if UNITY_2020_2_OR_NEWER
                    if (request.result == UnityWebRequest.Result.Success)
#else
                    if (!request.isNetworkError && !request.isHttpError)
#endif
                    {
                        return request.downloadHandler.text;
                    }
                    Debug.LogWarning($"[InferenceEngineService] Failed to load tokenizer from {candidate}: {request.error}");
                }
            }
            throw new FileNotFoundException("Tokenizer configuration file could not be loaded from StreamingAssets on Android.");
        }
#endif

        // Resolves pad/eos/bos ids from explicit config values first, then from
        // well-known special-token strings, finally falling back to defaults.
        private void ApplyTokenizerMetadata(TokenizerMetadata metadata)
        {
            if (metadata.PadTokenId.HasValue)
            {
                _padTokenId = metadata.PadTokenId.Value;
            }
            else if (metadata.SpecialTokenByValue.TryGetValue("<pad>", out var inferredPad))
            {
                _padTokenId = inferredPad;
            }
            else if (metadata.SpecialTokenByValue.TryGetValue("[PAD]", out var squarePad))
            {
                _padTokenId = squarePad;
            }
            else
            {
                _padTokenId = 0;
            }

            if (metadata.EosTokenId.HasValue)
            {
                _eosTokenId = metadata.EosTokenId.Value;
            }
            else if (metadata.SpecialTokenByValue.TryGetValue("</s>", out var inferredEos))
            {
                _eosTokenId = inferredEos;
            }
            else if (metadata.SpecialTokenByValue.TryGetValue("[SEP]", out var sepEos))
            {
                _eosTokenId = sepEos;
            }
            else
            {
                _eosTokenId = -1;
            }

            if (metadata.BosTokenId.HasValue)
            {
                _bosTokenId = metadata.BosTokenId;
            }
            else if (metadata.SpecialTokenByValue.TryGetValue("<s>", out var bos))
            {
                _bosTokenId = bos;
            }
            else if (metadata.SpecialTokenByValue.TryGetValue("[CLS]", out var cls))
            {
                _bosTokenId = cls;
            }
            else
            {
                _bosTokenId = null;
            }
        }

        // Returns the tokenizer.json path, preferring the InferenceEngine subfolder.
        private string GetTokenizerPath()
        {
            var streamingAssetsPath = UnityEngine.Application.streamingAssetsPath;
#if UNITY_ANDROID && !UNITY_EDITOR
            return $"{streamingAssetsPath}/{TokenizerFolderName}/{TokenizerFileName}";
#else
            var candidate = Path.Combine(streamingAssetsPath, TokenizerFolderName, TokenizerFileName);
            if (File.Exists(candidate))
            {
                return candidate;
            }
            candidate = Path.Combine(streamingAssetsPath, TokenizerFileName);
            return candidate;
#endif
        }

        // Prepends the BOS token when configured and not already present.
        // NOTE(review): when the list is full this drops the LAST (most recent)
        // token to make room — kept as-is; confirm that trimming the tail rather
        // than the head is intended.
        private void EnsureBosToken(List<int> tokens)
        {
            if (_bosTokenId.HasValue && (tokens.Count == 0 || tokens[0] != _bosTokenId.Value))
            {
                tokens.Insert(0, _bosTokenId.Value);
                if (tokens.Count > MaxSequenceLength)
                {
                    tokens.RemoveAt(tokens.Count - 1);
                }
            }
        }

        // Encodes one prompt, strips trailing inactive positions, keeps only the
        // newest MaxSequenceLength tokens and ensures BOS.
        private List<int> PrepareSingleTurnPrompt(string userInput, ITokenizerBridge tokenizer)
        {
            var encoding = tokenizer.Encode(userInput, addSpecialTokens: true);
            var ids = new List<int>(encoding.Ids);
            var attention = new List<int>(encoding.Attention);
            TrimInactive(ids, attention);
            if (ids.Count > MaxSequenceLength)
            {
                ids = ids.Skip(ids.Count - MaxSequenceLength).Take(MaxSequenceLength).ToList();
            }
            EnsureBosToken(ids);
            return ids;
        }

        // Autoregressive generation loop: one worker dispatch per new token.
        // Tensor uploads and worker scheduling happen on the main thread; sampling
        // runs on the thread pool.
        private async UniTask<List<int>> GenerateResponseTokensAsync(List<int> promptTokens)
        {
            var tokens = new List<int>(promptTokens);
            var generated = new List<int>();
            var occurrence = BuildInitialCounts(tokens);
            for (var step = 0; step < MaxGeneratedTokens; step++)
            {
                await UniTask.SwitchToMainThread();
                var effectiveLength = PopulateInputBuffers(tokens);
                if (effectiveLength <= 0)
                {
                    break;
                }
                _inputTensor.Upload(_inputBuffer);
                _attentionTensor.Upload(_attentionBuffer);
                _inputSlots[_inputIdsIndex] = _inputTensor;
                _inputSlots[_attentionMaskIndex] = _attentionTensor;
                _worker.Schedule(_inputSlots);
                var logitsTensor = _worker.PeekOutput("logits") as Tensor<float>;
                if (logitsTensor == null)
                {
                    break;
                }
                // Synchronous readback; blocks until the dispatch has finished.
                var logits = logitsTensor.DownloadToArray();
                var vocabSize = logitsTensor.shape[2];
                // Logits for the next token live at the last active position.
                var lastIndex = Math.Min(effectiveLength, MaxSequenceLength) - 1;
                if (lastIndex < 0)
                {
                    break;
                }
                await UniTask.SwitchToThreadPool();
                var offset = lastIndex * vocabSize;
                var nextToken = SampleNextToken(logits, offset, vocabSize, occurrence);
                tokens.Add(nextToken);
                generated.Add(nextToken);
                IncrementCount(nextToken, occurrence);
                if (ShouldStop(nextToken))
                {
                    // Do not surface the terminator itself in the response.
                    generated.RemoveAt(generated.Count - 1);
                    DecrementCount(nextToken, occurrence);
                    break;
                }
            }
            await UniTask.SwitchToMainThread();
            return generated;
        }

        // Copies the newest window of tokens into the fixed-size input/attention
        // buffers (right-padded with pad id / zeros) and returns the active length.
        private int PopulateInputBuffers(List<int> tokens)
        {
            Array.Fill(_inputBuffer, _padTokenId);
            Array.Clear(_attentionBuffer, 0, _attentionBuffer.Length);
            var length = tokens.Count;
            if (length == 0)
            {
                return 0;
            }
            var window = Math.Min(length, MaxSequenceLength);
            var start = length - window;
            for (var i = 0; i < window; i++)
            {
                _inputBuffer[i] = tokens[start + i];
                _attentionBuffer[i] = 1;
            }
            return window;
        }

        // Decodes generated token ids back to text, dropping special tokens.
        private string DecodeResponse(List<int> tokens, ITokenizerBridge tokenizer)
        {
            if (tokens.Count == 0)
            {
                return string.Empty;
            }
            var decoded = tokenizer.Decode(tokens, skipSpecialTokens: true);
            return decoded;
        }

        // Generation terminates on EOS or PAD.
        private bool ShouldStop(int tokenId)
        {
            if (_eosTokenId >= 0 && tokenId == _eosTokenId)
            {
                return true;
            }
            if (_padTokenId >= 0 && tokenId == _padTokenId)
            {
                return true;
            }
            return false;
        }

        // Occurrence counts of the prompt tokens, used for repetition penalty.
        private Dictionary<int, int> BuildInitialCounts(List<int> tokens)
        {
            var counts = new Dictionary<int, int>();
            foreach (var token in tokens)
            {
                IncrementCount(token, counts);
            }
            return counts;
        }

        private void IncrementCount(int tokenId, Dictionary<int, int> counts)
        {
            if (counts.TryGetValue(tokenId, out var existing))
            {
                counts[tokenId] = existing + 1;
            }
            else
            {
                counts[tokenId] = 1;
            }
        }

        private void DecrementCount(int tokenId, Dictionary<int, int> counts)
        {
            if (!counts.TryGetValue(tokenId, out var existing))
            {
                return;
            }
            if (existing <= 1)
            {
                counts.Remove(tokenId);
            }
            else
            {
                counts[tokenId] = existing - 1;
            }
        }

        /// <summary>
        /// Top-k sampling with temperature and a CTRL-style repetition penalty.
        /// EOS/PAD are excluded from the candidate pool; falls back to argmax when
        /// sampling is disabled or no candidate survives.
        /// </summary>
        private int SampleNextToken(float[] logits, int offset, int length, Dictionary<int, int> occurrence)
        {
            if (SamplingTemperature <= 0.0f || SamplingTopK <= 1)
            {
                return ArgMax(logits, offset, length);
            }
            var topCandidates = new List<(int token, float logit)>(SamplingTopK);
            for (var i = 0; i < length; i++)
            {
                if ((_eosTokenId >= 0 && i == _eosTokenId) || (_padTokenId >= 0 && i == _padTokenId))
                {
                    continue;
                }
                var adjustedLogit = logits[offset + i];
                if (occurrence.TryGetValue(i, out var count) && count > 0)
                {
                    // BUGFIX: CTRL-style repetition penalty must shrink positive logits
                    // and amplify negative ones. Unconditionally dividing (the previous
                    // behavior) made already-unlikely repeated tokens MORE likely.
                    adjustedLogit = adjustedLogit > 0f
                        ? adjustedLogit / RepetitionPenalty
                        : adjustedLogit * RepetitionPenalty;
                }
                adjustedLogit /= SamplingTemperature;
                if (topCandidates.Count < SamplingTopK)
                {
                    topCandidates.Add((i, adjustedLogit));
                    continue;
                }
                // Replace the weakest candidate when a stronger one appears.
                var minIndex = 0;
                var minLogit = topCandidates[0].logit;
                for (var k = 1; k < topCandidates.Count; k++)
                {
                    if (topCandidates[k].logit < minLogit)
                    {
                        minLogit = topCandidates[k].logit;
                        minIndex = k;
                    }
                }
                if (adjustedLogit > minLogit)
                {
                    topCandidates[minIndex] = (i, adjustedLogit);
                }
            }
            if (topCandidates.Count == 0)
            {
                return ArgMax(logits, offset, length);
            }
            // Numerically stable softmax over the surviving candidates.
            float maxLogit = float.NegativeInfinity;
            foreach (var candidate in topCandidates)
            {
                if (candidate.logit > maxLogit)
                {
                    maxLogit = candidate.logit;
                }
            }
            var probabilities = new List<(int token, float probability)>(topCandidates.Count);
            var total = 0f;
            foreach (var candidate in topCandidates)
            {
                var exp = Mathf.Exp(candidate.logit - maxLogit);
                total += exp;
                probabilities.Add((candidate.token, exp));
            }
            if (total <= 0f)
            {
                return topCandidates[0].token;
            }
            var sample = (float)_random.NextDouble() * total;
            var cumulative = 0f;
            foreach (var candidate in probabilities)
            {
                cumulative += candidate.probability;
                if (sample <= cumulative)
                {
                    return candidate.token;
                }
            }
            return probabilities[probabilities.Count - 1].token;
        }

        // Index of the largest value within data[offset .. offset + length).
        private static int ArgMax(float[] data, int offset, int length)
        {
            var maxIndex = 0;
            var maxValue = float.NegativeInfinity;
            for (var i = 0; i < length; i++)
            {
                var value = data[offset + i];
                if (value > maxValue)
                {
                    maxValue = value;
                    maxIndex = i;
                }
            }
            return maxIndex;
        }

        // Removes trailing positions whose attention mask is 0 from both lists.
        private static void TrimInactive(List<int> ids, List<int> attention)
        {
            if (attention.Count != ids.Count)
            {
                return;
            }
            var lastActive = attention.Count - 1;
            while (lastActive >= 0 && attention[lastActive] == 0)
            {
                lastActive--;
            }
            if (lastActive < 0)
            {
                ids.Clear();
                attention.Clear();
                return;
            }
            var removeCount = ids.Count - (lastActive + 1);
            if (removeCount > 0)
            {
                ids.RemoveRange(lastActive + 1, removeCount);
                attention.RemoveRange(lastActive + 1, removeCount);
            }
        }

        // Dispatches on the tokenizer.json "model.type": Unigram (SentencePiece)
        // gets the hand-rolled adapter, everything else is treated as BPE.
        private TokenizerLoadResult BuildTokenizer(string tokenizerJson)
        {
            var config = JObject.Parse(tokenizerJson);
            var metadataResult = new TokenizerMetadata();
            var model = config["model"] as JObject
                ?? throw new InvalidOperationException("Tokenizer config missing model section.");
            var modelType = model["type"]?.Value<string>() ?? string.Empty;
            if (string.Equals(modelType, "unigram", StringComparison.OrdinalIgnoreCase))
            {
                return BuildSentencePieceTokenizer(config, model, metadataResult);
            }
            return BuildBpeTokenizer(config, model, metadataResult);
        }

        // Builds a byte-level BPE tokenizer from the Hugging Face config
        // (vocab + merges + added_tokens).
        private TokenizerLoadResult BuildBpeTokenizer(JObject config, JObject model, TokenizerMetadata metadata)
        {
            var vocab = model["vocab"] as JObject
                ?? throw new InvalidOperationException("Tokenizer config missing vocab section.");
            var mergesToken = model["merges"] as JArray;
            if (mergesToken == null)
            {
                mergesToken = config["merges"] as JArray;
            }
            var vocabulary = new Dictionary<string, int>(vocab.Count);
            foreach (var property in vocab.Properties())
            {
                var token = property.Name;
                var id = property.Value.Value<int>();
                vocabulary[token] = id;
            }
            var merges = new List<MergePair>();
            if (mergesToken != null)
            {
                foreach (var merge in mergesToken)
                {
                    var parts = merge.Value<string>()?.Split(' ');
                    if (parts != null && parts.Length == 2)
                    {
                        merges.Add(new MergePair(parts[0], parts[1]));
                    }
                }
            }
            var addedTokens = BuildAddedTokens(config, metadata);
            metadata.PadTokenId = config["pad_token_id"]?.Value<int?>();
            metadata.BosTokenId = config["bos_token_id"]?.Value<int?>();
            metadata.EosTokenId = config["eos_token_id"]?.Value<int?>();
            metadata.UnknownTokenId = config["unk_token_id"]?.Value<int?>();
            var mapperOptions = new BpeMapperOptions
            {
                UnknownToken = model["unk_token"]?.Value<string>(),
                SubWordPrefix = model["continuing_subword_prefix"]?.Value<string>(),
                WordSuffix = model["end_of_word_suffix"]?.Value<string>(),
                ByteFallback = model["byte_fallback"]?.Value<bool?>()
            };
            var mapper = new BpeMapper(vocabulary, merges, mapperOptions);
            var preTokenizerConfig = config["pre_tokenizer"]?["add_prefix_space"]?.Value<bool?>();
            var addPrefixSpace = preTokenizerConfig ?? true;
            var preTokenizer = new ByteLevelPreTokenizer(addPrefixSpace: addPrefixSpace);
            var decoder = new ByteLevelDecoder();
            var tokenizer = new Tokenizer(
                mapper,
                normalizer: null,
                preTokenizer: preTokenizer,
                truncator: null,
                postProcessor: null,
                paddingProcessor: null,
                decoder: decoder,
                addedVocabulary: addedTokens.Configurations);
            var adapter = new SentisBpeTokenizerAdapter(tokenizer, metadata.SpecialTokenIds);
            return new TokenizerLoadResult(adapter, metadata);
        }

        // Builds the greedy longest-match SentencePiece adapter from a Unigram vocab.
        private TokenizerLoadResult BuildSentencePieceTokenizer(JObject config, JObject model, TokenizerMetadata metadata)
        {
            var vocabArray = model["vocab"] as JArray
                ?? throw new InvalidOperationException("SentencePiece tokenizer missing vocab array.");
            var idToToken = new List<string>(vocabArray.Count);
            var tokenToId = new Dictionary<string, int>(StringComparer.Ordinal);
            for (var i = 0; i < vocabArray.Count; i++)
            {
                var entry = vocabArray[i] as JArray
                    ?? throw new InvalidOperationException("SentencePiece vocab entry is invalid.");
                if (entry.Count < 1)
                {
                    continue;
                }
                var token = entry[0].Value<string>() ?? string.Empty;
                EnsureCapacity(idToToken, i);
                idToToken[i] = token;
                tokenToId[token] = i;
            }
            var addedTokens = BuildAddedTokens(config, metadata);
            foreach (var info in addedTokens.Infos)
            {
                EnsureCapacity(idToToken, info.Id);
                idToToken[info.Id] = info.Value;
                tokenToId[info.Value] = info.Id;
            }
            metadata.PadTokenId = config["pad_token_id"]?.Value<int?>();
            metadata.BosTokenId = config["bos_token_id"]?.Value<int?>();
            metadata.EosTokenId = config["eos_token_id"]?.Value<int?>();
            metadata.UnknownTokenId = config["unk_token_id"]?.Value<int?>() ?? model["unk_id"]?.Value<int?>();
            var adapter = new SentencePieceTokenizerAdapter(
                idToToken.ToArray(),
                tokenToId,
                metadata,
                metadata.UnknownTokenId);
            return new TokenizerLoadResult(adapter, metadata);
        }

        // Parses the added_tokens array; special tokens are also recorded in metadata.
        private static AddedTokenBuildResult BuildAddedTokens(JObject config, TokenizerMetadata metadata)
        {
            var addedTokensToken = config["added_tokens"] as JArray;
            if (addedTokensToken == null)
            {
                return new AddedTokenBuildResult(Array.Empty<TokenConfiguration>(), Array.Empty<AddedTokenInfo>());
            }
            var tokens = new List<TokenConfiguration>();
            var infos = new List<AddedTokenInfo>();
            foreach (var token in addedTokensToken.OfType<JObject>())
            {
                var id = token["id"].Value<int>();
                var value = token["content"].Value<string>();
                var singleWord = token["single_word"]?.Value<bool>() ?? false;
                var lstrip = token["lstrip"]?.Value<bool>() ?? false;
                var rstrip = token["rstrip"]?.Value<bool>() ?? false;
                var normalized = token["normalized"]?.Value<bool>() ?? false;
                var special = token["special"]?.Value<bool>() ?? false;
                var stripDirection = Direction.None;
                if (lstrip)
                {
                    stripDirection |= Direction.Left;
                }
                if (rstrip)
                {
                    stripDirection |= Direction.Right;
                }
                tokens.Add(new TokenConfiguration(id, value, singleWord, stripDirection, normalized, special));
                infos.Add(new AddedTokenInfo(id, value, special));
                if (special)
                {
                    metadata.SpecialTokenByValue[value] = id;
                    metadata.SpecialTokenIds.Add(id);
                }
            }
            return new AddedTokenBuildResult(tokens, infos);
        }

        // Special-token ids and values collected while parsing tokenizer.json.
        private class TokenizerMetadata
        {
            public int? PadTokenId;
            public int? BosTokenId;
            public int? EosTokenId;
            public int? UnknownTokenId;
            public Dictionary<string, int> SpecialTokenByValue { get; } = new Dictionary<string, int>(StringComparer.Ordinal);
            public HashSet<int> SpecialTokenIds { get; } = new HashSet<int>();
        }

        // Minimal encode/decode surface shared by both tokenizer backends.
        private interface ITokenizerBridge
        {
            TokenEncoding Encode(string text, bool addSpecialTokens);
            string Decode(IReadOnlyList<int> tokens, bool skipSpecialTokens);
        }

        private readonly struct TokenEncoding
        {
            public TokenEncoding(IReadOnlyList<int> ids, IReadOnlyList<int> attention)
            {
                Ids = ids;
                Attention = attention;
            }
            public IReadOnlyList<int> Ids { get; }
            public IReadOnlyList<int> Attention { get; }
        }

        private readonly struct TokenizerLoadResult
        {
            public TokenizerLoadResult(ITokenizerBridge tokenizer, TokenizerMetadata metadata)
            {
                Tokenizer = tokenizer;
                Metadata = metadata;
            }
            public ITokenizerBridge Tokenizer { get; }
            public TokenizerMetadata Metadata { get; }
        }

        private readonly struct AddedTokenBuildResult
        {
            public AddedTokenBuildResult(IReadOnlyList<TokenConfiguration> configurations, IReadOnlyList<AddedTokenInfo> infos)
            {
                Configurations = configurations;
                Infos = infos;
            }
            public IReadOnlyList<TokenConfiguration> Configurations { get; }
            public IReadOnlyList<AddedTokenInfo> Infos { get; }
        }

        private readonly struct AddedTokenInfo
        {
            public AddedTokenInfo(int id, string value, bool special)
            {
                Id = id;
                Value = value;
                Special = special;
            }
            public int Id { get; }
            public string Value { get; }
            public bool Special { get; }
        }

        // Wraps the Unity.InferenceEngine BPE tokenizer; filtering of special
        // tokens on decode is done manually by id.
        private sealed class SentisBpeTokenizerAdapter : ITokenizerBridge
        {
            private readonly Tokenizer _tokenizer;
            private readonly HashSet<int> _specialTokenIds;

            public SentisBpeTokenizerAdapter(Tokenizer tokenizer, HashSet<int> specialTokenIds)
            {
                _tokenizer = tokenizer;
                _specialTokenIds = specialTokenIds ?? new HashSet<int>();
            }

            public TokenEncoding Encode(string text, bool addSpecialTokens)
            {
                var encoding = _tokenizer.Encode(text, addSpecialTokens: addSpecialTokens);
                var ids = new List<int>();
                var attention = new List<int>();
                foreach (var sequence in encoding.GetEncodings())
                {
                    ids.AddRange(sequence.GetIds());
                    attention.AddRange(sequence.GetAttentionMask());
                }
                return new TokenEncoding(ids, attention);
            }

            public string Decode(IReadOnlyList<int> tokens, bool skipSpecialTokens)
            {
                if (!skipSpecialTokens)
                {
                    return _tokenizer.Decode(tokens, skipSpecialTokens: false);
                }
                var filtered = tokens.Where(t => !_specialTokenIds.Contains(t)).ToList();
                return _tokenizer.Decode(filtered, skipSpecialTokens: false);
            }
        }

        // Hand-rolled SentencePiece-style tokenizer: whitespace is replaced by '▁'
        // during normalization and segmentation is greedy longest-match.
        private sealed class SentencePieceTokenizerAdapter : ITokenizerBridge
        {
            private readonly string[] _idToToken;
            private readonly Dictionary<string, int> _tokenToId;
            private readonly HashSet<int> _specialTokenIds;
            private readonly int? _bosTokenId;
            private readonly int? _eosTokenId;
            private readonly int? _unknownTokenId;
            private readonly int _maxTokenLength;

            public SentencePieceTokenizerAdapter(
                string[] idToToken,
                Dictionary<string, int> tokenToId,
                TokenizerMetadata metadata,
                int? unknownTokenId)
            {
                _idToToken = idToToken;
                _tokenToId = tokenToId;
                _specialTokenIds = metadata.SpecialTokenIds;
                _bosTokenId = metadata.BosTokenId
                    ?? (metadata.SpecialTokenByValue.TryGetValue("<s>", out var bos) ? bos : (int?)null);
                _eosTokenId = metadata.EosTokenId
                    ?? (metadata.SpecialTokenByValue.TryGetValue("</s>", out var eos) ? eos : (int?)null);
                _unknownTokenId = unknownTokenId
                    ?? (metadata.SpecialTokenByValue.TryGetValue("<unk>", out var unk) ? unk : (int?)null);
                // Upper bound for the longest-match window in Segment().
                _maxTokenLength = idToToken.Max(t => t?.Length ?? 0);
            }

            public TokenEncoding Encode(string text, bool addSpecialTokens)
            {
                var ids = new List<int>();
                if (addSpecialTokens && _bosTokenId.HasValue)
                {
                    ids.Add(_bosTokenId.Value);
                }
                var normalized = Normalize(text);
                var pieces = Segment(normalized);
                ids.AddRange(pieces);
                if (addSpecialTokens && _eosTokenId.HasValue)
                {
                    ids.Add(_eosTokenId.Value);
                }
                var attention = Enumerable.Repeat(1, ids.Count).ToList();
                return new TokenEncoding(ids, attention);
            }

            public string Decode(IReadOnlyList<int> tokens, bool skipSpecialTokens)
            {
                var builder = new StringBuilder();
                foreach (var tokenId in tokens)
                {
                    if (tokenId < 0 || tokenId >= _idToToken.Length)
                    {
                        continue;
                    }
                    if (skipSpecialTokens && _specialTokenIds.Contains(tokenId))
                    {
                        continue;
                    }
                    var token = _idToToken[tokenId];
                    if (string.IsNullOrEmpty(token))
                    {
                        continue;
                    }
                    // '▁' marks a word boundary; render it as a space (never leading).
                    if (token == "▁")
                    {
                        if (builder.Length > 0)
                        {
                            builder.Append(' ');
                        }
                        continue;
                    }
                    if (token[0] == '▁')
                    {
                        if (builder.Length > 0)
                        {
                            builder.Append(' ');
                        }
                        builder.Append(token.Substring(1));
                    }
                    else
                    {
                        builder.Append(token);
                    }
                }
                return builder.ToString().Trim();
            }

            // Greedy longest-match segmentation; unmatched characters map to <unk>
            // (or are skipped when no unknown id is configured).
            private List<int> Segment(string normalized)
            {
                var results = new List<int>();
                if (string.IsNullOrEmpty(normalized))
                {
                    return results;
                }
                var length = normalized.Length;
                var index = 0;
                while (index < length)
                {
                    var matched = false;
                    var maxLength = Math.Min(_maxTokenLength, length - index);
                    for (var span = maxLength; span > 0; span--)
                    {
                        var candidate = normalized.Substring(index, span);
                        if (_tokenToId.TryGetValue(candidate, out var tokenId))
                        {
                            results.Add(tokenId);
                            index += span;
                            matched = true;
                            break;
                        }
                    }
                    if (!matched)
                    {
                        if (_unknownTokenId.HasValue)
                        {
                            results.Add(_unknownTokenId.Value);
                        }
                        index++;
                    }
                }
                return results;
            }

            // Collapses whitespace runs and prefixes each word with '▁'.
            private static string Normalize(string text)
            {
                if (string.IsNullOrEmpty(text))
                {
                    return string.Empty;
                }
                var builder = new StringBuilder();
                var newWord = true;
                foreach (var ch in text)
                {
                    if (char.IsWhiteSpace(ch))
                    {
                        newWord = true;
                        continue;
                    }
                    if (newWord)
                    {
                        builder.Append('▁');
                        newWord = false;
                    }
                    builder.Append(ch);
                }
                return builder.ToString();
            }
        }

        // Grows the list with empty strings until the given index is addressable.
        private static void EnsureCapacity(List<string> list, int index)
        {
            while (list.Count <= index)
            {
                list.Add(string.Empty);
            }
        }
    }
}
スクリプトをシーンに配置し、シーンを再生します。

コメントを入力すると推論結果が返ってきました。
