// HansonServo/nodegraph.cpp
//
// (Listing metadata: 354 lines, 8.9 KiB, C++)

#include "nodegraph.h"
#include "animation.h"
#include <cstring>
// CurveNode evaluation
//
// Samples the attached animation's curve `curveID` at `tick` and publishes it
// as outputValue. The `animation` pointer is set by
// NodeGraph::bindAnimationContext(). While it is unset, the node outputs the
// fixed fallback 2048 (presumably a mid-range servo position — TODO confirm).
void CurveNode::evaluate(uint32_t tick) {
    if (animation != nullptr) {
        outputValue = animation->getMotorPosition(curveID, tick);
        return;
    }
    outputValue = 2048;
}
// Linear search over `nodes` for the first node whose id equals `id`.
// Returns nullptr when no node carries that id.
Node* NodeGraph::findNodeByID(uint8_t id) const {
    auto it = nodes.begin();
    while (it != nodes.end() && (*it)->id != id) {
        ++it;
    }
    return (it != nodes.end()) ? *it : nullptr;
}
// ServoNode evaluation
void ServoNode::evaluate(uint32_t tick) {
outputValue = inputValue;
}
// Defined elsewhere in the project.
extern uint16_t getSineWaveValue(uint32_t tick);
// Face-tracking inputs are not wired up yet:
// extern int faceDetectX();
// extern int faceDetectY();

// Publishes the value of an external input selected by `source`:
//   VAR_SINE   -> getSineWaveValue(currentTick)
//   VAR_ANALOG -> analogRead() on a hard-coded pin (7)
// VAR_FACE_X / VAR_FACE_Y and any unrecognised source leave outputValue
// unchanged (the face-detection hooks are commented out).
void VariableNode::evaluate(uint32_t currentTick) {
    constexpr int kAnalogInputPin = 7;  // or whichever pin the hardware uses

    if (source == VAR_SINE) {
        outputValue = getSineWaveValue(currentTick);
    } else if (source == VAR_ANALOG) {
        outputValue = analogRead(kAnalogInputPin);
    }
    // VAR_FACE_X: outputValue = faceDetectX();  (not implemented yet)
    // VAR_FACE_Y: outputValue = faceDetectY();  (not implemented yet)
}
// Applies `op` with the node's constant operand `value` to inputValue:
//   OP_MULTIPLY: input * value
//   OP_DIVIDE:   input / value (0 when value == 0)
//   OP_ADD:      input + value
//   OP_SUBTRACT: input - value
// An unrecognised op yields 0. The float result is truncated toward zero and
// saturated into the uint16_t output range [0, 65535].
void MathNode::evaluate(uint32_t tick) {
    (void)tick;

    // Converting a float that is negative (<= -1), >= 65536 or NaN straight to
    // uint16_t is undefined behaviour (e.g. OP_SUBTRACT with value > input),
    // so clamp first. In-range values convert exactly as a plain cast would.
    const auto saturateToU16 = [](float v) -> uint16_t {
        if (!(v > 0.0f)) return 0;          // negative, zero, or NaN
        if (v >= 65535.0f) return 65535;
        return static_cast<uint16_t>(v);
    };

    const float input = static_cast<float>(inputValue);
    float result = 0.0f;
    switch (op) {
        case OP_MULTIPLY: result = input * value; break;
        case OP_DIVIDE:   result = (value != 0) ? input / value : 0.0f; break;
        case OP_ADD:      result = input + value; break;
        case OP_SUBTRACT: result = input - value; break;
    }
    outputValue = saturateToU16(result);
}
// Linearly rescales inputValue from [inMin, inMax] to [outMin, outMax]:
//   out = (in - inMin) * (outMax - outMin) / (inMax - inMin) + outMin
// The input is not clamped, so values outside [inMin, inMax] extrapolate.
// A degenerate input range (inMax == inMin) outputs outMin. The result is
// saturated into [0, 65535] instead of wrapping, so a slightly negative
// result gives 0 rather than ~65535.
void MapNode::evaluate(uint32_t tick) {
    (void)tick;

    // Works whether the range members are float (as loadNodeGraph fills them)
    // or integral. For floats it also avoids the undefined behaviour of
    // converting out-of-range values or NaN to an integer type.
    const auto saturateToU16 = [](auto v) -> uint16_t {
        if (!(v > 0)) return 0;             // negative, zero, or NaN
        if (v >= 65535) return 65535;
        return static_cast<uint16_t>(v);
    };

    if (inMax == inMin) {
        outputValue = saturateToU16(outMin);
        return;
    }
    const int32_t input = inputValue;
    const auto mapped = (input - inMin) * (outMax - outMin) / (inMax - inMin) + outMin;
    outputValue = saturateToU16(mapped);
}
// NodeGraph tick
void NodeGraph::tick(uint32_t currentTick, const Animation& animation) {
// Step 1: Evaluate each node and propagate outputs immediately
for (Node* node : nodes) {
node->evaluate(currentTick);
for (const NodeConnection& conn : connections) {
if (conn.fromID == node->id) {
Node* to = findNodeByID(conn.toID);
if (to) {
to->inputValue = node->outputValue; // ✅ generic propagation
}
}
}
}
}
std::vector<std::pair<uint8_t, uint16_t>> NodeGraph::getServoOutputs() const {
std::vector<std::pair<uint8_t, uint16_t>> outputs;
for (Node* node : nodes) {
if (node->type == TYPE_SERVONODE) {
const ServoNode* servo = static_cast<const ServoNode*>(node);
outputs.emplace_back(servo->motorID, servo->outputValue);
}
}
return outputs;
}
std::vector<Node*> NodeGraph::getSortedNodes() {
std::unordered_map<uint8_t, std::vector<uint8_t>> adj;
std::unordered_map<uint8_t, int> inDegree;
std::unordered_map<uint8_t, Node*> idMap;
for (Node* node : nodes) {
idMap[node->id] = node;
inDegree[node->id] = 0;
}
for (const NodeConnection& conn : connections) {
adj[conn.fromID].push_back(conn.toID);
inDegree[conn.toID]++;
}
std::vector<Node*> sorted;
std::queue<uint8_t> q;
for (const auto& [id, deg] : inDegree) {
if (deg == 0) q.push(id);
}
while (!q.empty()) {
uint8_t id = q.front();
q.pop();
sorted.push_back(idMap[id]);
for (uint8_t neighbor : adj[id]) {
if (--inDegree[neighbor] == 0) {
q.push(neighbor);
}
}
}
return sorted;
}
// Links nodes to their outside inputs: every CurveNode in the graph is given
// `animation` as the source it samples in CurveNode::evaluate().
// Other node types are left untouched (add their bindings here as needed).
void NodeGraph::bindAnimationContext(Animation* animation) {
    for (auto it = nodes.begin(); it != nodes.end(); ++it) {
        Node* current = *it;
        if (current->type != TYPE_CURVENODE) {
            continue;
        }
        static_cast<CurveNode*>(current)->animation = animation;
    }
}
void loadNodeGraph(const uint8_t* packet, size_t length, NodeGraph& graph) {
size_t offset = 0;
// Read node count
uint16_t nodeCount = packet[offset];
offset += 1;
// Parse nodes
for (uint16_t i = 0; i < nodeCount; ++i) {
if (offset + 6 > length) break; // safety check
uint8_t type = packet[offset++];
uint8_t id = packet[offset++];
uint16_t x = packet[offset] | (packet[offset + 1] << 8);
offset += 2;
uint16_t y = packet[offset] | (packet[offset + 1] << 8);
offset += 2;
Node* node = nullptr;
switch (type) {
case TYPE_CURVENODE:
{ // CurveNode
if (offset + 1 > length) break;
uint8_t curveID = packet[offset++];
auto* curve = new CurveNode();
curve->id = id;
curve->type = type;
curve->x = x;
curve->y = y;
curve->curveID = curveID;
node = curve;
break;
}
case TYPE_SERVONODE:
{ // ServoNode
if (offset + 1 > length) break;
uint8_t motorID = packet[offset++];
auto* servo = new ServoNode();
servo->id = id;
servo->type = type;
servo->x = x;
servo->y = y;
servo->motorID = motorID;
node = servo;
break;
}
case TYPE_VARIABLENODE:
{ // ServoNode
if (offset + 1 > length) break;
uint8_t source = packet[offset++];
auto* newNode = new VariableNode();
newNode->id = id;
newNode->type = type;
newNode->x = x;
newNode->y = y;
newNode->source = static_cast<VariableSource>(source);
node = newNode;
break;
}
case TYPE_MATHNODE:
{
if (offset + 5 > length) break; // 1 byte op + 4 bytes float
uint8_t rawOp = packet[offset++];
float val;
memcpy(&val, &packet[offset], sizeof(float));
offset += sizeof(float);
auto* math = new MathNode();
math->id = id;
math->type = type;
math->x = x;
math->y = y;
math->op = static_cast<MathOperator>(rawOp);
math->value = val;
node = math;
break;
}
case TYPE_MAPNODE:
{
if (offset + 16 > length) break; // 4 x float32 = 16 bytes
float inMin, inMax, outMin, outMax;
memcpy(&inMin, &packet[offset], sizeof(float));
offset += 4;
memcpy(&inMax, &packet[offset], sizeof(float));
offset += 4;
memcpy(&outMin, &packet[offset], sizeof(float));
offset += 4;
memcpy(&outMax, &packet[offset], sizeof(float));
offset += 4;
auto* mapNode = new MapNode();
mapNode->id = id;
mapNode->type = type;
mapNode->x = x;
mapNode->y = y;
mapNode->inMin = inMin;
mapNode->inMax = inMax;
mapNode->outMin = outMin;
mapNode->outMax = outMax;
node = mapNode;
break;
}
case TYPE_NOISENODE:
{ // NoiseNode
if (offset + 17 > length) break;
offset += 17; // skip for now
break;
}
default:
break;
}
if (node) {
graph.nodes.push_back(node);
}
// Sort node list topologically for execution.
graph.nodes = graph.getSortedNodes();
}
// Parse connections
graph.connections.clear();
if (offset + 2 > length) return;
uint16_t connectionCount = packet[offset];
offset += 1;
for (uint16_t i = 0; i < connectionCount; ++i) {
if (offset + 2 > length) break;
uint8_t fromID = packet[offset++];
uint8_t toID = packet[offset++];
graph.connections.push_back({ fromID, toID });
}
}
// Builds a human-readable debug dump of `graph`. It has one line per node
// (ID, type, position and type-specific fields), then one "from -> to" line
// per connection.
String printNodeGraph(const NodeGraph& graph) {
    String output = "📦 NodeGraph Dump\n";
    output += "Nodes:\n";
    for (const Node* node : graph.nodes) {
        output += " ID " + String(node->id);
        output += " | Type " + String(node->type);
        output += " | Pos (" + String(node->x) + ", " + String(node->y) + ")";
        switch (node->type) {
            case TYPE_CURVENODE: {
                const CurveNode* curve = static_cast<const CurveNode*>(node);
                output += " | CurveID " + String(curve->curveID);
                break;
            }
            case TYPE_SERVONODE: {
                const ServoNode* servo = static_cast<const ServoNode*>(node);
                output += " | MotorID " + String(servo->motorID);
                break;
            }
            case TYPE_VARIABLENODE: {
                const VariableNode* variable = static_cast<const VariableNode*>(node);
                output += " | source: " + String(variable->source);
                break;
            }
            case TYPE_MATHNODE: {
                const MathNode* math = static_cast<const MathNode*>(node);
                output += " | op " + String(static_cast<int>(math->op));
                output += " value " + String(static_cast<float>(math->value), 2);
                break;
            }
            case TYPE_MAPNODE: {
                // Explicit float casts: String(x, 2) on an integer would print base 2.
                const MapNode* mapNode = static_cast<const MapNode*>(node);
                output += " | in [" + String(static_cast<float>(mapNode->inMin), 2) + ", " +
                          String(static_cast<float>(mapNode->inMax), 2) + "]";
                output += " out [" + String(static_cast<float>(mapNode->outMin), 2) + ", " +
                          String(static_cast<float>(mapNode->outMax), 2) + "]";
                break;
            }
            // NoiseNode payloads are currently skipped by loadNodeGraph, so no
            // NoiseNode is ever created. Once it is, something like:
            // case TYPE_NOISENODE: {
            //     const NoiseNode* noise = static_cast<const NoiseNode*>(node);
            //     output += " | Noise a=" + String(noise->a, 2);
            //     output += " b=" + String(noise->b, 2);
            //     output += " c=" + String(noise->c, 2);
            //     output += " d=" + String(noise->d, 2);
            //     output += " seed=" + String(noise->seed);
            //     break;
            // }
            default:
                break;
        }
        output += "\n";
    }
    output += "Connections:\n";
    for (const auto& conn : graph.connections) {
        // Without a separator, 1->2 and 12->(nothing) were indistinguishable ("12").
        output += " " + String(conn.fromID) + " -> " + String(conn.toID) + "\n";
    }
    return output;
}