
GPU Coder automatic code generation for DAGNetwork fails

Oualid on 2 Aug 2022
Edited: Oualid on 4 Aug 2022
Dear community,
I am trying to generate C/C++ code from a DAGNetwork that was imported from a YOLOv7 ONNX model. When I run this model in Simulink using the Predict block, no errors occur and everything works. I saved the DAGNetwork as a *.mat file; loading and running the model again also works well.
Next, I want to generate C/C++ code from the model via Simulink, but I get the error message: ??? Layer hyper-parameters for custom layer 'concatenationLayer' must be numeric scalar, scalar logical, character or string array, or a matrix of type double or single.
Once I remove this layer, 'concatenationLayer', code generation works perfectly.
I understand that the concatenationLayer is a custom layer, but I am wondering how I can get rid of this error.
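The error message lists the value types that a custom layer property may hold. As a rough way to spot offending properties before running code generation, the sketch below encodes that list as a predicate and applies it to a stand-in structure. The predicate isAllowed, the variable layerProps, and its contents are hypothetical and only mirror the wording of the error message; the exact rules applied by the code generator may differ (for example for arrays with more than two dimensions).

```matlab
% Predicate mirroring the wording of the error message (an assumption,
% not the code generator's actual check).
isAllowed = @(v) (isscalar(v) && (isnumeric(v) || islogical(v))) ...
    || ischar(v) || isstring(v) ...
    || ((isa(v, 'double') || isa(v, 'single')) && ismatrix(v));

% Hypothetical stand-in for the properties of the custom layer.
layerProps = struct('Name', 'Reshape_To_ConcatLayer1211', 'NumInputs', 3);
layerProps.ONNXParams = struct('Learnables', struct(), 'State', struct());

% Collect the names of properties whose values fail the predicate.
flagged = {};
names = fieldnames(layerProps);
for k = 1:numel(names)
    value = layerProps.(names{k});
    if ~isAllowed(value)
        flagged{end+1} = names{k}; %#ok<SAGROW>
        fprintf('Property %s has class %s\n', names{k}, class(value));
    end
end
```

With a real network, the same loop could be run over the property names of each custom layer instead of the stand-in structure.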
4 Comments
Sergio Matiz Romero on 3 Aug 2022
Hi Oualid,
Thank you for sharing additional details on the issue you are facing. Based on the code you shared, the custom layer property ONNXParams in Reshape_To_ConcatLayer1211 is likely a structure:
[output, outputNumDims, state] = Reshape_To_ConcatGraph1200(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, NumDims.onnx__Reshape_488, NumDims.onnx__Reshape_527, NumDims.onnx__Reshape_566, Vars, NumDims, Training, params.State);
Structure types are unsupported as layer properties for deep learning code generation based on:
Therefore, my recommendation in this case would be to modify the code so that the structure fields are stored as individual properties in the Reshape_To_ConcatLayer constructor. An example of what this would look like is provided below:
classdef Reshape_To_ConcatLayer1211 < nnet.layer.Layer & nnet.layer.Formattable
    properties
        OnnxParam1
        OnnxParam2
        OnnxParam3
    end
    methods
        function this = Reshape_To_ConcatLayer1211(name, onnxParams)
            this.Name = name;
            this.NumInputs = 3;
            this.OutputNames = {'output'};
            % Below I use generic names for the different properties and
            % store them individually, avoiding the use of structures.
            % Each value must be assigned to the object (this.), otherwise
            % it only creates a local variable and the property stays empty.
            this.OnnxParam1 = onnxParams.Param1;
            this.OnnxParam2 = onnxParams.Param2;
            this.OnnxParam3 = onnxParams.Param3;
        end
    end
end
Then you would have to modify the code so that it picks up the individual properties instead of using the structure ONNXParams.
I hope this helps you solve the issue you are facing.
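As a minimal sketch of what "picking up the individual properties" could look like, the snippet below rebuilds a local structure from separately stored values, so that code written against Vars.<name> can stay mostly unchanged. The variable names and values are hypothetical stand-ins; inside the layer they would be read from this.OnnxParam1, this.OnnxParam2 and this.OnnxParam3. Whether the remaining auto-generated functions are accepted by the code generator is not confirmed here.

```matlab
% Hypothetical stand-ins for values stored as individual layer properties.
onnxParam1 = single([1 2 3]);
onnxParam2 = 4;
onnxParam3 = 'as-is';

% Rebuild a local structure inside the method. The structure lives only in
% the function workspace and is not stored as a layer property.
Vars = struct('Param1', onnxParam1, 'Param2', onnxParam2, 'Param3', onnxParam3);
disp(fieldnames(Vars));
```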
Oualid on 4 Aug 2022
Edited: Oualid on 4 Aug 2022
Hi Sergio,
Thank you for your assistance. I changed the code as you suggested but still have the same issue. Here are the steps I followed:
1. Load the ONNX model as a DAGNetwork.
net = importONNXNetwork(modelfile,OutputDataFormats="TBC")
2. Create my own layer graph:
Lgraph = net.layerGraph
3. Extract the ONNX parameters of the layer causing the problem (Reshape_To_ConcatLayer1211):
Param = Lgraph.Layers(300,1).ONNXParams
4. Create a new layer named Reshape_To_ConcatLayer1211 and pass the 5 parameters one by one. I also changed the auto-generated layer to accept 5 parameter inputs in the way you suggested.
layer = Reshape_To_ConcatLayer1211('Reshape_To_ConcatFcn',Param.Learnables,Param.Nonlearnables,Param.State,Param.NumDimensions,Param.NetworkFunctionName);
5. Replace the old layer in Lgraph and create a new graph named mygraph:
mygraph= replaceLayer(Lgraph,'Reshape_To_ConcatLayer1211',layer,'ReconnectBy','order')
6. Create my own network:
mynet = assembleNetwork(mygraph)
7. Save it as a .mat file.
8. Load it into Simulink and run the simulation: everything works perfectly.
9. Automatic code generation fails with exactly the same error.
The modified code is below:
classdef Reshape_To_ConcatLayer1211 < nnet.layer.Layer & nnet.layer.Formattable
% A custom layer auto-generated while importing an ONNX network.
%#codegen
%#ok<*PROPLC>
%#ok<*NBRAK>
%#ok<*INUSL>
%#ok<*VARARG>
properties (Learnable)
end
properties
OnnxParam1
OnnxParam2
OnnxParam3
OnnxParam4
OnnxParam5
end
methods
function this = Reshape_To_ConcatLayer1211(name, onnxParam1,onnxParam2,onnxParam3,onnxParam4,onnxParam5)
this.Name = name;
this.NumInputs = 3;
this.OutputNames = {'output'};
this.Description = 'output';
this.Type = 'Relu';
this.OnnxParam1 = onnxParam1;
this.OnnxParam2 = onnxParam2;
this.OnnxParam3 = onnxParam3;
this.OnnxParam4 = onnxParam4;
this.OnnxParam5 = onnxParam5;
end
function [output] = predict(this, onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566)
if isdlarray(onnx__Reshape_488)
onnx__Reshape_488 = stripdims(onnx__Reshape_488);
end
if isdlarray(onnx__Reshape_527)
onnx__Reshape_527 = stripdims(onnx__Reshape_527);
end
if isdlarray(onnx__Reshape_566)
onnx__Reshape_566 = stripdims(onnx__Reshape_566);
end
onnx__Reshape_488NumDims = 4;
onnx__Reshape_527NumDims = 4;
onnx__Reshape_566NumDims = 4;
Param1=this.OnnxParam1;
Param2= this.OnnxParam2;
Param3= this.OnnxParam3;
Param4= this.OnnxParam4;
Param5= this.OnnxParam5;
[output, outputNumDims] = Reshape_To_ConcatFcn(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, onnx__Reshape_488NumDims, onnx__Reshape_527NumDims, onnx__Reshape_566NumDims,Param1,Param2,Param3,Param4,Param5, 'Training', false, ...
'InputDataPermutation', {[4 3 1 2], [4 3 1 2], [4 3 1 2], ['as-is'], ['as-is'], ['as-is']}, ...
'OutputDataPermutation', {[3 2 1], ['as-is']});
if any(cellfun(@(A)isempty(A)||~isnumeric(A), {output}))
fprintf('Runtime error in network. The custom layer ''%s'' output an empty or non-numeric value.\n', 'Reshape_To_ConcatLayer1211');
error(message('nnet_cnn_onnx:onnx:BadCustomLayerRuntimeOutput', 'Reshape_To_ConcatLayer1211'));
end
output = dlarray(single(output), 'CBT');
if ~coder.target('MATLAB')
output = extractdata(output);
end
end
function [output] = forward(this, onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566)
if isdlarray(onnx__Reshape_488)
onnx__Reshape_488 = stripdims(onnx__Reshape_488);
end
if isdlarray(onnx__Reshape_527)
onnx__Reshape_527 = stripdims(onnx__Reshape_527);
end
if isdlarray(onnx__Reshape_566)
onnx__Reshape_566 = stripdims(onnx__Reshape_566);
end
onnx__Reshape_488NumDims = 4;
onnx__Reshape_527NumDims = 4;
onnx__Reshape_566NumDims = 4;
Param1 = this.OnnxParam1;
Param2= this.OnnxParam2;
Param3= this.OnnxParam3;
Param4= this.OnnxParam4;
Param5= this.OnnxParam5;
[output, outputNumDims] = Reshape_To_ConcatFcn(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, onnx__Reshape_488NumDims, onnx__Reshape_527NumDims, onnx__Reshape_566NumDims,Param1,Param2,Param3,Param4,Param5, 'Training', true, ...
'InputDataPermutation', {[4 3 1 2], [4 3 1 2], [4 3 1 2], ['as-is'], ['as-is'], ['as-is']}, ...
'OutputDataPermutation', {[3 2 1], ['as-is']});
if any(cellfun(@(A)isempty(A)||~isnumeric(A), {output}))
fprintf('Runtime error in network. The custom layer ''%s'' output an empty or non-numeric value.\n', 'Reshape_To_ConcatLayer1211');
error(message('nnet_cnn_onnx:onnx:BadCustomLayerRuntimeOutput', 'Reshape_To_ConcatLayer1211'));
end
output = dlarray(single(output), 'CBT');
if ~coder.target('MATLAB')
output = extractdata(output);
end
end
end
end
function [output, outputNumDims, state] = Reshape_To_ConcatFcn(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, onnx__Reshape_488NumDims, onnx__Reshape_527NumDims, onnx__Reshape_566NumDims, Param1,Param2,Param3,Param4,Param5,varargin)
% Preprocess the input data and arguments:
[onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, Training, outputDataPerms, anyDlarrayInputs] = preprocessInput(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, varargin{:});
% Put all variables into a single struct to implement dynamic scoping:
[Vars, NumDims] = packageVariables(Param1,Param2,Param3,Param4, {'onnx__Reshape_488', 'onnx__Reshape_527', 'onnx__Reshape_566'}, {onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566}, [onnx__Reshape_488NumDims onnx__Reshape_527NumDims onnx__Reshape_566NumDims]);
% Call the top-level graph function:
[output, outputNumDims, state] = Reshape_To_ConcatGraph1200(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, NumDims.onnx__Reshape_488, NumDims.onnx__Reshape_527, NumDims.onnx__Reshape_566, Vars, NumDims, Training, Param3);
% Postprocess the output data
[output] = postprocessOutput(output, outputDataPerms, anyDlarrayInputs, Training, varargin{:});
end
function [output, outputNumDims1210, state] = Reshape_To_ConcatGraph1200(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, onnx__Reshape_488NumDims1207, onnx__Reshape_527NumDims1208, onnx__Reshape_566NumDims1209, Vars, NumDims, Training, state)
% Function implementing the graph 'Reshape_To_ConcatGraph1200'
% Update Vars and NumDims from the graph's formal input parameters. Note that state variables are already in Vars.
Vars.onnx__Reshape_488 = onnx__Reshape_488;
NumDims.onnx__Reshape_488 = onnx__Reshape_488NumDims1207;
Vars.onnx__Reshape_527 = onnx__Reshape_527;
NumDims.onnx__Reshape_527 = onnx__Reshape_527NumDims1208;
Vars.onnx__Reshape_566 = onnx__Reshape_566;
NumDims.onnx__Reshape_566 = onnx__Reshape_566NumDims1209;
% Execute the operators:
% Reshape:
[shape, NumDims.onnx__Transpose_500] = prepareReshapeArgs(Vars.onnx__Reshape_488, Vars.onnx__Reshape_613, NumDims.onnx__Reshape_488, 0);
Vars.onnx__Transpose_500 = reshape(Vars.onnx__Reshape_488, shape{:});
% Transpose:
[perm, NumDims.onnx__Sigmoid_501] = prepareTransposeArgs(Vars.TransposePerm1201, NumDims.onnx__Transpose_500);
if ~isempty(perm)
Vars.onnx__Sigmoid_501 = permute(Vars.onnx__Transpose_500, perm);
end
% Sigmoid:
Vars.y = sigmoid(Vars.onnx__Sigmoid_501);
NumDims.y = NumDims.onnx__Sigmoid_501;
% Split:
[Vars.onnx__Mul_503, Vars.onnx__Mul_504, Vars.onnx__Concat_505, NumDims.onnx__Mul_503, NumDims.onnx__Mul_504, NumDims.onnx__Concat_505] = onnxSplit(Vars.y, 4, Vars.SplitSplit1202, 0, NumDims.y);
% Mul:
Vars.onnx__Add_507 = Vars.onnx__Mul_503 .* Vars.onnx__Pow_614;
NumDims.onnx__Add_507 = max(NumDims.onnx__Mul_503, NumDims.onnx__Pow_614);
% Add:
Vars.onnx__Mul_509 = Vars.onnx__Add_507 + Vars.onnx__Add_508;
NumDims.onnx__Mul_509 = max(NumDims.onnx__Add_507, NumDims.onnx__Add_508);
% Mul:
Vars.onnx__Concat_511 = Vars.onnx__Mul_509 .* Vars.onnx__Mul_510;
NumDims.onnx__Concat_511 = max(NumDims.onnx__Mul_509, NumDims.onnx__Mul_510);
% Mul:
Vars.onnx__Pow_513 = Vars.onnx__Mul_504 .* Vars.onnx__Pow_614;
NumDims.onnx__Pow_513 = max(NumDims.onnx__Mul_504, NumDims.onnx__Pow_614);
% Pow:
Vars.onnx__Mul_516 = power(Vars.onnx__Pow_513, Vars.onnx__Pow_614);
NumDims.onnx__Mul_516 = max(NumDims.onnx__Pow_513, NumDims.onnx__Pow_614);
% Mul:
Vars.onnx__Concat_518 = Vars.onnx__Mul_516 .* Vars.onnx__Mul_517;
NumDims.onnx__Concat_518 = max(NumDims.onnx__Mul_516, NumDims.onnx__Mul_517);
% Concat:
[Vars.onnx__Reshape_519, NumDims.onnx__Reshape_519] = onnxConcat(4, {Vars.onnx__Concat_511, Vars.onnx__Concat_518, Vars.onnx__Concat_505}, [NumDims.onnx__Concat_511, NumDims.onnx__Concat_518, NumDims.onnx__Concat_505]);
% Reshape:
[shape, NumDims.onnx__Concat_526] = prepareReshapeArgs(Vars.onnx__Reshape_519, Vars.onnx__Reshape_618, NumDims.onnx__Reshape_519, 0);
Vars.onnx__Concat_526 = reshape(Vars.onnx__Reshape_519, shape{:});
% Reshape:
[shape, NumDims.onnx__Transpose_539] = prepareReshapeArgs(Vars.onnx__Reshape_527, Vars.onnx__Reshape_624, NumDims.onnx__Reshape_527, 0);
Vars.onnx__Transpose_539 = reshape(Vars.onnx__Reshape_527, shape{:});
% Transpose:
[perm, NumDims.onnx__Sigmoid_540] = prepareTransposeArgs(Vars.TransposePerm1203, NumDims.onnx__Transpose_539);
if ~isempty(perm)
Vars.onnx__Sigmoid_540 = permute(Vars.onnx__Transpose_539, perm);
end
% Sigmoid:
Vars.y_3 = sigmoid(Vars.onnx__Sigmoid_540);
NumDims.y_3 = NumDims.onnx__Sigmoid_540;
% Split:
[Vars.onnx__Mul_542, Vars.onnx__Mul_543, Vars.onnx__Concat_544, NumDims.onnx__Mul_542, NumDims.onnx__Mul_543, NumDims.onnx__Concat_544] = onnxSplit(Vars.y_3, 4, Vars.SplitSplit1204, 0, NumDims.y_3);
% Mul:
Vars.onnx__Add_546 = Vars.onnx__Mul_542 .* Vars.onnx__Pow_614;
NumDims.onnx__Add_546 = max(NumDims.onnx__Mul_542, NumDims.onnx__Pow_614);
% Add:
Vars.onnx__Mul_548 = Vars.onnx__Add_546 + Vars.onnx__Add_547;
NumDims.onnx__Mul_548 = max(NumDims.onnx__Add_546, NumDims.onnx__Add_547);
% Mul:
Vars.onnx__Concat_550 = Vars.onnx__Mul_548 .* Vars.onnx__Mul_549;
NumDims.onnx__Concat_550 = max(NumDims.onnx__Mul_548, NumDims.onnx__Mul_549);
% Mul:
Vars.onnx__Pow_552 = Vars.onnx__Mul_543 .* Vars.onnx__Pow_614;
NumDims.onnx__Pow_552 = max(NumDims.onnx__Mul_543, NumDims.onnx__Pow_614);
% Pow:
Vars.onnx__Mul_555 = power(Vars.onnx__Pow_552, Vars.onnx__Pow_614);
NumDims.onnx__Mul_555 = max(NumDims.onnx__Pow_552, NumDims.onnx__Pow_614);
% Mul:
Vars.onnx__Concat_557 = Vars.onnx__Mul_555 .* Vars.onnx__Mul_556;
NumDims.onnx__Concat_557 = max(NumDims.onnx__Mul_555, NumDims.onnx__Mul_556);
% Concat:
[Vars.onnx__Reshape_558, NumDims.onnx__Reshape_558] = onnxConcat(4, {Vars.onnx__Concat_550, Vars.onnx__Concat_557, Vars.onnx__Concat_544}, [NumDims.onnx__Concat_550, NumDims.onnx__Concat_557, NumDims.onnx__Concat_544]);
% Reshape:
[shape, NumDims.onnx__Concat_565] = prepareReshapeArgs(Vars.onnx__Reshape_558, Vars.onnx__Reshape_618, NumDims.onnx__Reshape_558, 0);
Vars.onnx__Concat_565 = reshape(Vars.onnx__Reshape_558, shape{:});
% Reshape:
[shape, NumDims.onnx__Transpose_578] = prepareReshapeArgs(Vars.onnx__Reshape_566, Vars.onnx__Reshape_635, NumDims.onnx__Reshape_566, 0);
Vars.onnx__Transpose_578 = reshape(Vars.onnx__Reshape_566, shape{:});
% Transpose:
[perm, NumDims.onnx__Sigmoid_579] = prepareTransposeArgs(Vars.TransposePerm1205, NumDims.onnx__Transpose_578);
if ~isempty(perm)
Vars.onnx__Sigmoid_579 = permute(Vars.onnx__Transpose_578, perm);
end
% Sigmoid:
Vars.y_7 = sigmoid(Vars.onnx__Sigmoid_579);
NumDims.y_7 = NumDims.onnx__Sigmoid_579;
% Split:
[Vars.onnx__Mul_581, Vars.onnx__Mul_582, Vars.onnx__Concat_583, NumDims.onnx__Mul_581, NumDims.onnx__Mul_582, NumDims.onnx__Concat_583] = onnxSplit(Vars.y_7, 4, Vars.SplitSplit1206, 0, NumDims.y_7);
% Mul:
Vars.onnx__Add_585 = Vars.onnx__Mul_581 .* Vars.onnx__Pow_614;
NumDims.onnx__Add_585 = max(NumDims.onnx__Mul_581, NumDims.onnx__Pow_614);
% Add:
Vars.onnx__Mul_587 = Vars.onnx__Add_585 + Vars.onnx__Add_586;
NumDims.onnx__Mul_587 = max(NumDims.onnx__Add_585, NumDims.onnx__Add_586);
% Mul:
Vars.onnx__Concat_589 = Vars.onnx__Mul_587 .* Vars.onnx__Mul_588;
NumDims.onnx__Concat_589 = max(NumDims.onnx__Mul_587, NumDims.onnx__Mul_588);
% Mul:
Vars.onnx__Pow_591 = Vars.onnx__Mul_582 .* Vars.onnx__Pow_614;
NumDims.onnx__Pow_591 = max(NumDims.onnx__Mul_582, NumDims.onnx__Pow_614);
% Pow:
Vars.onnx__Mul_594 = power(Vars.onnx__Pow_591, Vars.onnx__Pow_614);
NumDims.onnx__Mul_594 = max(NumDims.onnx__Pow_591, NumDims.onnx__Pow_614);
% Mul:
Vars.onnx__Concat_596 = Vars.onnx__Mul_594 .* Vars.onnx__Mul_595;
NumDims.onnx__Concat_596 = max(NumDims.onnx__Mul_594, NumDims.onnx__Mul_595);
% Concat:
[Vars.onnx__Reshape_597, NumDims.onnx__Reshape_597] = onnxConcat(4, {Vars.onnx__Concat_589, Vars.onnx__Concat_596, Vars.onnx__Concat_583}, [NumDims.onnx__Concat_589, NumDims.onnx__Concat_596, NumDims.onnx__Concat_583]);
% Reshape:
[shape, NumDims.onnx__Concat_604] = prepareReshapeArgs(Vars.onnx__Reshape_597, Vars.onnx__Reshape_618, NumDims.onnx__Reshape_597, 0);
Vars.onnx__Concat_604 = reshape(Vars.onnx__Reshape_597, shape{:});
% Concat:
[Vars.output, NumDims.output] = onnxConcat(1, {Vars.onnx__Concat_526, Vars.onnx__Concat_565, Vars.onnx__Concat_604}, [NumDims.onnx__Concat_526, NumDims.onnx__Concat_565, NumDims.onnx__Concat_604]);
% Set graph output arguments from Vars and NumDims:
output = Vars.output;
outputNumDims1210 = NumDims.output;
% Set output state from Vars:
state = updateStruct(state, Vars);
end
function [inputDataPerms, outputDataPerms, Training] = parseInputs(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, numDataOutputs, varargin)
% Function to validate inputs to Reshape_To_ConcatFcn:
p = inputParser;
isValidArrayInput = @(x)isnumeric(x) || isstring(x);
addRequired(p, 'onnx__Reshape_488', isValidArrayInput);
addRequired(p, 'onnx__Reshape_527', isValidArrayInput);
addRequired(p, 'onnx__Reshape_566', isValidArrayInput);
addParameter(p, 'InputDataPermutation', 'auto');
addParameter(p, 'OutputDataPermutation', 'auto');
addParameter(p, 'Training', false);
parse(p, onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, varargin{:});
inputDataPerms = p.Results.InputDataPermutation;
outputDataPerms = p.Results.OutputDataPermutation;
Training = p.Results.Training;
if isnumeric(inputDataPerms)
inputDataPerms = {inputDataPerms};
end
if isstring(inputDataPerms) && isscalar(inputDataPerms) || ischar(inputDataPerms)
inputDataPerms = repmat({inputDataPerms},1,3);
end
if isnumeric(outputDataPerms)
outputDataPerms = {outputDataPerms};
end
if isstring(outputDataPerms) && isscalar(outputDataPerms) || ischar(outputDataPerms)
outputDataPerms = repmat({outputDataPerms},1,numDataOutputs);
end
end
function [onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, Training, outputDataPerms, anyDlarrayInputs] = preprocessInput(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, varargin)
% Parse input arguments
[inputDataPerms, outputDataPerms, Training] = parseInputs(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, 1, varargin{:});
anyDlarrayInputs = any(cellfun(@(x)isa(x, 'dlarray'), {onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566}));
% Make the input variables into unlabelled dlarrays:
onnx__Reshape_488 = makeUnlabeledDlarray(onnx__Reshape_488);
onnx__Reshape_527 = makeUnlabeledDlarray(onnx__Reshape_527);
onnx__Reshape_566 = makeUnlabeledDlarray(onnx__Reshape_566);
% Permute inputs if requested:
onnx__Reshape_488 = permuteInputVar(onnx__Reshape_488, inputDataPerms{1}, 4);
onnx__Reshape_527 = permuteInputVar(onnx__Reshape_527, inputDataPerms{2}, 4);
onnx__Reshape_566 = permuteInputVar(onnx__Reshape_566, inputDataPerms{3}, 4);
end
function [output] = postprocessOutput(output, outputDataPerms, anyDlarrayInputs, Training, varargin)
% Set output type:
if ~anyDlarrayInputs && ~Training
if isdlarray(output)
output = extractdata(output);
end
end
% Permute outputs if requested:
output = permuteOutputVar(output, outputDataPerms{1}, 3);
end
%% dlarray functions implementing ONNX operators:
function [Y, numDimsY] = onnxConcat(ONNXAxis, XCell, numDimsXArray)
% Concatenation that treats all empties the same. Necessary because
% dlarray.cat does not allow, for example, cat(1, 1x1, 1x0) because the
% second dimension sizes do not match.
numDimsY = numDimsXArray(1);
XCell(cellfun(@isempty, XCell)) = [];
if isempty(XCell)
Y = dlarray([]);
else
if ONNXAxis<0
ONNXAxis = ONNXAxis + numDimsY;
end
DLTAxis = numDimsY - ONNXAxis;
Y = cat(DLTAxis, XCell{:});
end
end
function varargout = onnxSplit(X, ONNXaxis, splits, numSplits, numDimsX)
% Implements the ONNX Split operator
% ONNXaxis is origin 0. splits is a vector of the lengths of each segment.
% If numSplits is nonzero, instead split into segments of equal length.
if ONNXaxis<0
ONNXaxis = ONNXaxis + numDimsX;
end
DLTAxis = numDimsX - ONNXaxis;
if numSplits > 0
C = size(X, DLTAxis);
sz = floor(C/numSplits);
splits = repmat(sz, 1, numSplits);
else
splits = extractdata(splits);
end
S = struct;
S.type = '()';
S.subs = repmat({':'}, 1, ndims(X));
splitIndices = [0 cumsum(splits(:)')];
numY = numel(splitIndices)-1;
for i = 1:numY
from = splitIndices(i) + 1;
to = splitIndices(i+1);
S.subs{DLTAxis} = from:to;
% The first numY outputs are the Y's. The second numY outputs are their
% numDims. We assume all the outputs of Split have the same numDims as
% the input.
varargout{i} = subsref(X, S);
varargout{i + numY} = numDimsX;
end
end
function [DLTShape, numDimsY] = prepareReshapeArgs(X, ONNXShape, numDimsX, allowzero)
% Prepares arguments for implementing the ONNX Reshape operator
ONNXShape = flip(extractdata(ONNXShape)); % First flip the shape to make it correspond to the dimensions of X.
% In ONNX, 0 means "unchanged" if allowzero is false, and -1 means "infer". In DLT, there is no
% "unchanged", and [] means "infer".
DLTShape = num2cell(ONNXShape); % Make a cell array so we can include [].
% Replace zeros with the actual size if allowzero is false
if any(ONNXShape==0) && allowzero==0
i0 = find(ONNXShape==0);
DLTShape(i0) = num2cell(size(X, numDimsX - numel(ONNXShape) + i0)); % right-align the shape vector and dims
end
if any(ONNXShape == -1)
% Replace -1 with []
i = ONNXShape == -1;
DLTShape{i} = [];
end
if numel(DLTShape)==1
DLTShape = [DLTShape 1];
end
numDimsY = numel(ONNXShape);
end
function [perm, numDimsA] = prepareTransposeArgs(ONNXPerm, numDimsA)
% Prepares arguments for implementing the ONNX Transpose operator
if numDimsA <= 1 % Tensors of numDims 0 or 1 are unchanged by ONNX Transpose.
perm = [];
else
if isempty(ONNXPerm) % Empty ONNXPerm means reverse the dimensions.
perm = numDimsA:-1:1;
else
perm = numDimsA-flip(ONNXPerm);
end
end
end
%% Utility functions:
function s = appendStructs(varargin)
% s = appendStructs(s1, s2,...). Assign all fields in s1, s2,... into s.
if isempty(varargin)
s = struct;
else
s = varargin{1};
for i = 2:numel(varargin)
fromstr = varargin{i};
fs = fieldnames(fromstr);
for j = 1:numel(fs)
s.(fs{j}) = fromstr.(fs{j});
end
end
end
end
function checkInputSize(inputShape, expectedShape, inputName)
if numel(expectedShape)==0
% The input is a scalar
if ~isequal(inputShape, [1 1])
inputSizeStr = makeSizeString(inputShape);
error(message('nnet_cnn_onnx:onnx:InputNeedsResize',inputName, "[1,1]", inputSizeStr));
end
elseif numel(expectedShape)==1
% The input is a vector
if ~shapeIsColumnVector(inputShape) || ~iSizesMatch({inputShape(1)}, expectedShape)
expectedShape{2} = 1;
expectedSizeStr = makeSizeString(expectedShape);
inputSizeStr = makeSizeString(inputShape);
error(message('nnet_cnn_onnx:onnx:InputNeedsResize',inputName, expectedSizeStr, inputSizeStr));
end
else
% The input has 2 dimensions or more
% The input dimensions have been reversed; flip them back to compare to the
% expected ONNX shape.
inputShape = fliplr(inputShape);
% If the expected shape has fewer dims than the input shape, error.
if numel(expectedShape) < numel(inputShape)
expectedSizeStr = strjoin(["[", strjoin(string(expectedShape), ","), "]"], "");
error(message('nnet_cnn_onnx:onnx:InputHasGreaterNDims', inputName, expectedSizeStr));
end
% Prepad the input shape with leading ones up to the number of elements in
% expectedShape
inputShape = num2cell([ones(1, numel(expectedShape) - length(inputShape)) inputShape]);
% Find the number of variable size dimensions in the expected shape
numVariableInputs = sum(cellfun(@(x) isa(x, 'char') || isa(x, 'string'), expectedShape));
% Find the number of input dimensions that are not in the expected shape
% and cannot be represented by a variable dimension
nonMatchingInputDims = setdiff(string(inputShape), string(expectedShape));
numNonMatchingInputDims = numel(nonMatchingInputDims) - numVariableInputs;
expectedSizeStr = makeSizeString(expectedShape);
inputSizeStr = makeSizeString(inputShape);
if numNonMatchingInputDims == 0 && ~iSizesMatch(inputShape, expectedShape)
% The actual and expected input dimensions match, but in
% a different order. The input needs to be permuted.
error(message('nnet_cnn_onnx:onnx:InputNeedsPermute',inputName, expectedSizeStr, inputSizeStr));
elseif numNonMatchingInputDims > 0
% The actual and expected input sizes do not match.
error(message('nnet_cnn_onnx:onnx:InputNeedsResize',inputName, expectedSizeStr, inputSizeStr));
end
end
end
function doesMatch = iSizesMatch(inputShape, expectedShape)
% Check whether the input and expected shapes match, in order.
% Size elements match if (1) the elements are equal, or (2) the expected
% size element is a variable (represented by a character vector or string)
doesMatch = true;
for i=1:numel(inputShape)
if ~(isequal(inputShape{i},expectedShape{i}) || ischar(expectedShape{i}) || isstring(expectedShape{i}))
doesMatch = false;
return
end
end
end
function sizeStr = makeSizeString(shape)
sizeStr = strjoin(["[", strjoin(string(shape), ","), "]"], "");
end
function isVec = shapeIsColumnVector(shape)
if numel(shape) == 2 && shape(2) == 1
isVec = true;
else
isVec = false;
end
end
function X = makeUnlabeledDlarray(X)
% Make numeric X into an unlabelled dlarray
if isa(X, 'dlarray')
X = stripdims(X);
elseif isnumeric(X)
if isinteger(X)
% Make ints double so they can combine with anything without
% reducing precision
X = double(X);
end
X = dlarray(X);
end
end
function [Vars, NumDims] = packageVariables(Param1,Param2,Param3,Param4, inputNames, inputValues, inputNumDims)
% inputNames, inputValues are cell arrays. inputNumDims is a numeric vector.
Vars = appendStructs(Param1, Param2,Param3);
NumDims = Param4;
% Add graph inputs
for i = 1:numel(inputNames)
Vars.(inputNames{i}) = inputValues{i};
NumDims.(inputNames{i}) = inputNumDims(i);
end
end
function X = permuteInputVar(X, userDataPerm, onnxNDims)
% Returns reverse-ONNX ordering
if onnxNDims == 0
return;
elseif onnxNDims == 1 && isvector(X)
X = X(:);
return;
elseif isnumeric(userDataPerm)
% Permute into reverse ONNX ordering
if numel(userDataPerm) ~= onnxNDims
error(message('nnet_cnn_onnx:onnx:InputPermutationSize', numel(userDataPerm), onnxNDims));
end
perm = fliplr(userDataPerm);
elseif isequal(userDataPerm, 'auto') && onnxNDims == 4
% Permute MATLAB HWCN to reverse onnx (WHCN)
perm = [2 1 3 4];
elseif isequal(userDataPerm, 'as-is')
% Do not permute the input
perm = 1:ndims(X);
else
% userDataPerm is either 'none' or 'auto' with no default, which means
% it's already in onnx ordering, so just make it reverse onnx
perm = max(2,onnxNDims):-1:1;
end
X = permute(X, perm);
end
function Y = permuteOutputVar(Y, userDataPerm, onnxNDims)
switch onnxNDims
case 0
perm = [];
case 1
if isnumeric(userDataPerm)
% Use the user's permutation because Y is a column vector which
% already matches ONNX.
perm = userDataPerm;
elseif isequal(userDataPerm, 'auto')
% Treat the 1D onnx vector as a 2D column and transpose it
perm = [2 1];
else
% userDataPerm is 'none'. Leave Y alone because it already
% matches onnx.
perm = [];
end
otherwise
% ndims >= 2
if isnumeric(userDataPerm)
% Use the inverse of the user's permutation. This is not just the
% flip of the permutation vector.
perm = onnxNDims + 1 - userDataPerm;
elseif isequal(userDataPerm, 'auto')
if onnxNDims == 2
% Permute reverse ONNX CN to DLT CN (do nothing)
perm = [];
elseif onnxNDims == 4
% Permute reverse onnx (WHCN) to MATLAB HWCN
perm = [2 1 3 4];
else
% User wants the output in ONNX ordering, so just reverse it from
% reverse onnx
perm = onnxNDims:-1:1;
end
elseif isequal(userDataPerm, 'as-is')
% Do not permute the output
perm = 1:ndims(Y);
else
% userDataPerm is 'none', so just make it reverse onnx
perm = onnxNDims:-1:1;
end
end
if ~isempty(perm)
Y = permute(Y, perm);
end
end
function s = updateStruct(s, t)
% Set all existing fields in s from fields in t, ignoring extra fields in t.
for name = transpose(fieldnames(s))
s.(name{1}) = t.(name{1});
end
end
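
One thing that may be worth checking in the code above: packageVariables passes Param1, Param2 and Param3 to appendStructs, which calls fieldnames on them, and assigns fields on Param4 through NumDims. That suggests the first four values stored as layer properties are themselves structures, which would explain why the same error still appears. A small sketch for inspecting such a value one level down is shown here. The top-level field names come from step 4; the nested field names and all values are made up for illustration.

```matlab
% Stand-in for the value read from ONNXParams in step 3 (contents made up).
Param = struct();
Param.Learnables = struct();
Param.Nonlearnables = struct('onnx__Pow_614', single(2), 'TransposePerm1201', [0 1 3 4 2]);
Param.State = struct();
Param.NumDimensions = struct('onnx__Pow_614', 0, 'TransposePerm1201', 1);
Param.NetworkFunctionName = 'Reshape_To_ConcatFcn';

% Report which top-level fields are structures and list their leaves.
groups = fieldnames(Param);
for g = 1:numel(groups)
    value = Param.(groups{g});
    if isstruct(value)
        leaves = fieldnames(value);
        fprintf('%s is a struct with %d field(s)\n', groups{g}, numel(leaves));
        for k = 1:numel(leaves)
            leaf = value.(leaves{k});
            fprintf('  %s: %s %s\n', leaves{k}, class(leaf), mat2str(size(leaf)));
        end
    else
        fprintf('%s is a %s\n', groups{g}, class(value));
    end
end

% Count how many top-level fields are still structures.
nStruct = nnz(structfun(@isstruct, Param));
```

If the real value looks like this, each leaf (not each group) would need its own layer property for the approach suggested earlier in the thread to apply.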


Answers (0)
