PTA L3-023 Computation Graph, and some of my thoughts on computation graphs

Keywords: dynamic programming, number theory, math

Problem link

Let dp1[u] denote the value of the function formed by all nodes reachable from u, and let dp2[u] denote the partial derivative of that function with respect to the variable node varIdx. Here varIdx is enumerated one at a time outside the dp2 computation and passed in as an argument.

Because we build the reversed graph (each operator node points to its operands), the function value is dp1[st], where st is the unique node with in-degree zero (the node that comes first in topological order). The partial derivative is likewise dp2[st].

The problem statement describes evaluating the function as a process from left to right. However, since I use memoized search, I find it more convenient to compute both the function value and the partial derivatives from right to left (i.e. on the reversed graph).

Since this is DP on a DAG, the DP values can also be propagated during a topological sort. If you use a topological sort, I think it is convenient to store both the original graph and the reversed graph. The transition code would look the same, so I didn't bother writing that version.

The code is at the end of this post.

The solution to this problem is very simple, but I want to record some thoughts on computation graphs (this problem was the first time I had heard of them). If you have ever implemented symbolic differentiation of an expression (I did, a long time ago), you know it is not easy, let alone partial differentiation. There is a simpler route. First use recursive descent to build a computation graph; this is clearly much simpler than assembling the derivative string directly during recursive descent. Then assemble the partial-derivative string on the computation graph. The whole job becomes simpler because we have decomposed it into two simpler tasks.

Moreover, we can use a topological sort to assemble the partial-derivative strings, instead of being constrained by deeply nested recursive calls.

#include <bits/stdc++.h>
using namespace std;
// Loop helpers: inclusive ascending, half-open ascending, inclusive descending.
#define rep(idx,lo,hi) for(int idx = (lo); idx <= (hi); ++idx)
#define re_(idx,lo,hi) for(int idx = (lo); idx < (hi); ++idx)
#define dwn(idx,lo,hi) for(int idx = (lo); idx >= (hi); --idx)

// Capacity of the per-node arrays; nodes are stored 1-based (input index + 1).
static constexpr int N = 50005;

int n;                 // number of nodes read from input
int ideg[N];           // how many times this node is referenced as an operand
int odeg[N];           // number of operands of this node (filled by main, not otherwise used)
double dp1[N];         // memoized value of the sub-expression rooted at a node; main preloads variable values here
double dp2[N];         // memoized partial derivative of that sub-expression w.r.t. the current variable
bool vis1[N];          // dp1[u] is already computed
bool vis2[N];          // dp2[u] is already computed for the current variable
vector<int> G[N];      // operand nodes of each node (edges point from operator to operands)
int type[N];           // 0 = variable, 1 = +, 2 = -, 3 = *, 4 = exp, 5 = ln, otherwise sin
vector<int> vars;      // variable nodes, in input order

// Memoized evaluation of node u.
// Variable nodes (type 0) return the value main stored in dp1; binary operators
// (type <= 3) combine G[u][0] and G[u][1]; every other type is a unary operator
// applied to G[u][0].
double dfs1(int u){
    if(vis1[u]) return dp1[u];
    vis1[u] = true;
    const int op = type[u];
    if(op == 0) return dp1[u];          // variable: value was preloaded by main

    if(op > 3){                          // unary operator on G[u][0]
        const double x = dfs1(G[u][0]);
        switch(op){
        case 4:  dp1[u] = exp(x); break;
        case 5:  dp1[u] = log(x); break;
        default: dp1[u] = sin(x); break;
        }
        return dp1[u];
    }

    // binary operator: evaluate left operand first, then right
    const double lhs = dfs1(G[u][0]);
    const double rhs = dfs1(G[u][1]);
    switch(op){
    case 1:  dp1[u] = lhs + rhs; break;
    case 2:  dp1[u] = lhs - rhs; break;
    default: dp1[u] = lhs * rhs; break;
    }
    return dp1[u];
}

double dfs2(int u,int varIdx){
    double &dp = dp2[u];
    if(vis2[u]) return dp;
    vis2[u] = true;
    const int t = type[u];
    if(!t) return dp = varIdx == u;
    else if(t <= 3){
        double v1 = dfs1(G[u][0]),v2 = dfs1(G[u][1]);
        double d1 = dfs2(G[u][0],varIdx),d2 = dfs2(G[u][1],varIdx);
        if(t == 1) return dp = d1+d2;
        if(t == 2) return dp = d1-d2;
        return dp = d1*v2+v1*d2;
    }
    double v = dfs1(G[u][0]);
    double d = dfs2(G[u][0],varIdx);
    if(t == 4) return dp = d*exp(v);
    if(t == 5) return dp = d/v;
    return dp = d*cos(v);
}

int main() {
    scanf("%d",&n);
    rep(i,1,n){
        int typ;scanf("%d",&typ);
        type[i] = typ;
        if(!typ){
            double v;scanf("%lf",&v);
            vars.push_back(i);
            dp1[i] = v;
        }
        else if(typ <= 3){
            int p1,p2;scanf("%d%d",&p1,&p2);++p1;++p2;
            G[i].push_back(p1);
            G[i].push_back(p2);
            ideg[p1]++;ideg[p2]++;odeg[i] += 2;
        }
        else{
            int p;scanf("%d",&p);++p;
            G[i].push_back(p);
            ideg[p]++;odeg[i]++;
        }
    }
    int st = -1;
    double ans1;
    rep(i,1,n) if(!ideg[i]){
        st = i;
        ans1 = dfs1(i);break;
    }
    vector<double> ans2;
    for(int var: vars){
        rep(i,1,n) vis2[i] = false;
        ans2.push_back(dfs2(st,var));
    }
    printf("%.3lf\n",ans1);
    re_(i,0,ans2.size()) printf("%.3lf%c",ans2[i]," \n"[i+1 == ans2.size()]);
    return 0;
}

Posted by stefandv on Mon, 20 Sep 2021 08:40:16 -0700